pyg-nightly 2.6.0.dev20240319__py3-none-any.whl → 2.7.0.dev20250114__py3-none-any.whl

Files changed (226)
  1. {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/METADATA +31 -47
  2. {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/RECORD +226 -199
  3. {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/WHEEL +1 -1
  4. torch_geometric/__init__.py +28 -1
  5. torch_geometric/_compile.py +8 -1
  6. torch_geometric/_onnx.py +14 -0
  7. torch_geometric/config_mixin.py +113 -0
  8. torch_geometric/config_store.py +28 -19
  9. torch_geometric/data/__init__.py +24 -1
  10. torch_geometric/data/batch.py +2 -2
  11. torch_geometric/data/collate.py +8 -2
  12. torch_geometric/data/data.py +16 -8
  13. torch_geometric/data/database.py +61 -15
  14. torch_geometric/data/dataset.py +14 -6
  15. torch_geometric/data/feature_store.py +25 -42
  16. torch_geometric/data/graph_store.py +1 -5
  17. torch_geometric/data/hetero_data.py +18 -9
  18. torch_geometric/data/in_memory_dataset.py +2 -4
  19. torch_geometric/data/large_graph_indexer.py +677 -0
  20. torch_geometric/data/lightning/datamodule.py +4 -4
  21. torch_geometric/data/separate.py +6 -1
  22. torch_geometric/data/storage.py +17 -7
  23. torch_geometric/data/summary.py +14 -4
  24. torch_geometric/data/temporal.py +1 -2
  25. torch_geometric/datasets/__init__.py +17 -2
  26. torch_geometric/datasets/actor.py +9 -11
  27. torch_geometric/datasets/airfrans.py +15 -18
  28. torch_geometric/datasets/airports.py +10 -12
  29. torch_geometric/datasets/amazon.py +8 -11
  30. torch_geometric/datasets/amazon_book.py +9 -10
  31. torch_geometric/datasets/amazon_products.py +9 -10
  32. torch_geometric/datasets/aminer.py +8 -9
  33. torch_geometric/datasets/aqsol.py +10 -13
  34. torch_geometric/datasets/attributed_graph_dataset.py +10 -12
  35. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  36. torch_geometric/datasets/ba_shapes.py +5 -6
  37. torch_geometric/datasets/bitcoin_otc.py +1 -1
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/cornell.py +145 -0
  40. torch_geometric/datasets/dblp.py +2 -1
  41. torch_geometric/datasets/dbp15k.py +2 -2
  42. torch_geometric/datasets/fake.py +1 -3
  43. torch_geometric/datasets/flickr.py +2 -1
  44. torch_geometric/datasets/freebase.py +1 -1
  45. torch_geometric/datasets/gdelt_lite.py +3 -2
  46. torch_geometric/datasets/ged_dataset.py +3 -2
  47. torch_geometric/datasets/git_mol_dataset.py +263 -0
  48. torch_geometric/datasets/gnn_benchmark_dataset.py +11 -10
  49. torch_geometric/datasets/hgb_dataset.py +8 -8
  50. torch_geometric/datasets/imdb.py +2 -1
  51. torch_geometric/datasets/karate.py +3 -2
  52. torch_geometric/datasets/last_fm.py +2 -1
  53. torch_geometric/datasets/linkx_dataset.py +4 -3
  54. torch_geometric/datasets/lrgb.py +3 -5
  55. torch_geometric/datasets/malnet_tiny.py +4 -3
  56. torch_geometric/datasets/mnist_superpixels.py +2 -3
  57. torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
  58. torch_geometric/datasets/molecule_net.py +15 -3
  59. torch_geometric/datasets/motif_generator/base.py +0 -1
  60. torch_geometric/datasets/neurograph.py +1 -3
  61. torch_geometric/datasets/ogb_mag.py +1 -1
  62. torch_geometric/datasets/opf.py +239 -0
  63. torch_geometric/datasets/ose_gvcs.py +1 -1
  64. torch_geometric/datasets/pascal.py +11 -9
  65. torch_geometric/datasets/pascal_pf.py +1 -1
  66. torch_geometric/datasets/pcpnet_dataset.py +1 -1
  67. torch_geometric/datasets/pcqm4m.py +10 -3
  68. torch_geometric/datasets/ppi.py +1 -1
  69. torch_geometric/datasets/qm9.py +8 -7
  70. torch_geometric/datasets/rcdd.py +4 -4
  71. torch_geometric/datasets/reddit.py +2 -1
  72. torch_geometric/datasets/reddit2.py +2 -1
  73. torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
  74. torch_geometric/datasets/s3dis.py +5 -3
  75. torch_geometric/datasets/shapenet.py +3 -3
  76. torch_geometric/datasets/shrec2016.py +2 -2
  77. torch_geometric/datasets/snap_dataset.py +7 -1
  78. torch_geometric/datasets/tag_dataset.py +350 -0
  79. torch_geometric/datasets/upfd.py +2 -1
  80. torch_geometric/datasets/web_qsp_dataset.py +246 -0
  81. torch_geometric/datasets/webkb.py +2 -2
  82. torch_geometric/datasets/wikics.py +1 -1
  83. torch_geometric/datasets/wikidata.py +3 -2
  84. torch_geometric/datasets/wikipedia_network.py +2 -2
  85. torch_geometric/datasets/willow_object_class.py +1 -1
  86. torch_geometric/datasets/word_net.py +2 -2
  87. torch_geometric/datasets/yelp.py +2 -1
  88. torch_geometric/datasets/zinc.py +1 -1
  89. torch_geometric/device.py +42 -0
  90. torch_geometric/distributed/local_feature_store.py +3 -2
  91. torch_geometric/distributed/local_graph_store.py +2 -1
  92. torch_geometric/distributed/partition.py +9 -8
  93. torch_geometric/edge_index.py +616 -438
  94. torch_geometric/explain/algorithm/base.py +0 -1
  95. torch_geometric/explain/algorithm/graphmask_explainer.py +1 -2
  96. torch_geometric/explain/algorithm/pg_explainer.py +1 -1
  97. torch_geometric/explain/explanation.py +2 -2
  98. torch_geometric/graphgym/checkpoint.py +2 -1
  99. torch_geometric/graphgym/logger.py +4 -4
  100. torch_geometric/graphgym/loss.py +1 -1
  101. torch_geometric/graphgym/utils/agg_runs.py +6 -6
  102. torch_geometric/index.py +826 -0
  103. torch_geometric/inspector.py +8 -3
  104. torch_geometric/io/fs.py +28 -2
  105. torch_geometric/io/npz.py +2 -1
  106. torch_geometric/io/off.py +2 -2
  107. torch_geometric/io/sdf.py +2 -2
  108. torch_geometric/io/tu.py +4 -5
  109. torch_geometric/loader/__init__.py +4 -0
  110. torch_geometric/loader/cluster.py +10 -4
  111. torch_geometric/loader/graph_saint.py +2 -1
  112. torch_geometric/loader/ibmb_loader.py +12 -4
  113. torch_geometric/loader/mixin.py +1 -1
  114. torch_geometric/loader/neighbor_loader.py +1 -1
  115. torch_geometric/loader/neighbor_sampler.py +2 -2
  116. torch_geometric/loader/prefetch.py +1 -1
  117. torch_geometric/loader/rag_loader.py +107 -0
  118. torch_geometric/loader/utils.py +8 -7
  119. torch_geometric/loader/zip_loader.py +10 -0
  120. torch_geometric/metrics/__init__.py +11 -2
  121. torch_geometric/metrics/link_pred.py +159 -34
  122. torch_geometric/nn/aggr/__init__.py +4 -0
  123. torch_geometric/nn/aggr/attention.py +0 -2
  124. torch_geometric/nn/aggr/base.py +2 -4
  125. torch_geometric/nn/aggr/patch_transformer.py +143 -0
  126. torch_geometric/nn/aggr/set_transformer.py +1 -1
  127. torch_geometric/nn/aggr/variance_preserving.py +33 -0
  128. torch_geometric/nn/attention/__init__.py +5 -1
  129. torch_geometric/nn/attention/qformer.py +71 -0
  130. torch_geometric/nn/conv/collect.jinja +7 -4
  131. torch_geometric/nn/conv/cugraph/base.py +8 -12
  132. torch_geometric/nn/conv/edge_conv.py +3 -2
  133. torch_geometric/nn/conv/fused_gat_conv.py +1 -1
  134. torch_geometric/nn/conv/gat_conv.py +35 -7
  135. torch_geometric/nn/conv/gatv2_conv.py +36 -6
  136. torch_geometric/nn/conv/general_conv.py +1 -1
  137. torch_geometric/nn/conv/graph_conv.py +21 -3
  138. torch_geometric/nn/conv/gravnet_conv.py +3 -2
  139. torch_geometric/nn/conv/hetero_conv.py +3 -3
  140. torch_geometric/nn/conv/hgt_conv.py +1 -1
  141. torch_geometric/nn/conv/message_passing.py +138 -87
  142. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  143. torch_geometric/nn/conv/propagate.jinja +9 -1
  144. torch_geometric/nn/conv/rgcn_conv.py +5 -5
  145. torch_geometric/nn/conv/spline_conv.py +4 -4
  146. torch_geometric/nn/conv/x_conv.py +3 -2
  147. torch_geometric/nn/dense/linear.py +11 -6
  148. torch_geometric/nn/fx.py +3 -3
  149. torch_geometric/nn/model_hub.py +3 -1
  150. torch_geometric/nn/models/__init__.py +10 -2
  151. torch_geometric/nn/models/deep_graph_infomax.py +1 -2
  152. torch_geometric/nn/models/dimenet_utils.py +5 -7
  153. torch_geometric/nn/models/g_retriever.py +230 -0
  154. torch_geometric/nn/models/git_mol.py +336 -0
  155. torch_geometric/nn/models/glem.py +385 -0
  156. torch_geometric/nn/models/gnnff.py +0 -1
  157. torch_geometric/nn/models/graph_unet.py +12 -3
  158. torch_geometric/nn/models/jumping_knowledge.py +63 -4
  159. torch_geometric/nn/models/lightgcn.py +1 -1
  160. torch_geometric/nn/models/metapath2vec.py +5 -5
  161. torch_geometric/nn/models/molecule_gpt.py +222 -0
  162. torch_geometric/nn/models/node2vec.py +2 -3
  163. torch_geometric/nn/models/schnet.py +2 -1
  164. torch_geometric/nn/models/signed_gcn.py +3 -3
  165. torch_geometric/nn/module_dict.py +2 -2
  166. torch_geometric/nn/nlp/__init__.py +9 -0
  167. torch_geometric/nn/nlp/llm.py +322 -0
  168. torch_geometric/nn/nlp/sentence_transformer.py +134 -0
  169. torch_geometric/nn/nlp/vision_transformer.py +33 -0
  170. torch_geometric/nn/norm/batch_norm.py +1 -1
  171. torch_geometric/nn/parameter_dict.py +2 -2
  172. torch_geometric/nn/pool/__init__.py +21 -5
  173. torch_geometric/nn/pool/cluster_pool.py +145 -0
  174. torch_geometric/nn/pool/connect/base.py +0 -1
  175. torch_geometric/nn/pool/edge_pool.py +1 -1
  176. torch_geometric/nn/pool/graclus.py +4 -2
  177. torch_geometric/nn/pool/pool.py +8 -2
  178. torch_geometric/nn/pool/select/base.py +0 -1
  179. torch_geometric/nn/pool/voxel_grid.py +3 -2
  180. torch_geometric/nn/resolver.py +1 -1
  181. torch_geometric/nn/sequential.jinja +10 -23
  182. torch_geometric/nn/sequential.py +204 -78
  183. torch_geometric/nn/summary.py +1 -1
  184. torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
  185. torch_geometric/profile/__init__.py +2 -0
  186. torch_geometric/profile/nvtx.py +66 -0
  187. torch_geometric/profile/profiler.py +30 -19
  188. torch_geometric/resolver.py +1 -1
  189. torch_geometric/sampler/base.py +34 -13
  190. torch_geometric/sampler/neighbor_sampler.py +11 -10
  191. torch_geometric/sampler/utils.py +1 -1
  192. torch_geometric/template.py +1 -0
  193. torch_geometric/testing/__init__.py +6 -2
  194. torch_geometric/testing/decorators.py +53 -20
  195. torch_geometric/testing/feature_store.py +1 -1
  196. torch_geometric/transforms/__init__.py +2 -0
  197. torch_geometric/transforms/add_metapaths.py +5 -5
  198. torch_geometric/transforms/add_positional_encoding.py +1 -1
  199. torch_geometric/transforms/delaunay.py +65 -14
  200. torch_geometric/transforms/face_to_edge.py +32 -3
  201. torch_geometric/transforms/gdc.py +7 -6
  202. torch_geometric/transforms/laplacian_lambda_max.py +3 -3
  203. torch_geometric/transforms/mask.py +5 -1
  204. torch_geometric/transforms/node_property_split.py +1 -2
  205. torch_geometric/transforms/pad.py +7 -6
  206. torch_geometric/transforms/random_link_split.py +1 -1
  207. torch_geometric/transforms/remove_self_loops.py +36 -0
  208. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  209. torch_geometric/transforms/to_sparse_tensor.py +1 -1
  210. torch_geometric/transforms/two_hop.py +1 -1
  211. torch_geometric/transforms/virtual_node.py +2 -1
  212. torch_geometric/typing.py +43 -6
  213. torch_geometric/utils/__init__.py +5 -1
  214. torch_geometric/utils/_negative_sampling.py +1 -1
  215. torch_geometric/utils/_normalize_edge_index.py +46 -0
  216. torch_geometric/utils/_scatter.py +38 -12
  217. torch_geometric/utils/_subgraph.py +4 -0
  218. torch_geometric/utils/_tree_decomposition.py +2 -2
  219. torch_geometric/utils/augmentation.py +1 -1
  220. torch_geometric/utils/convert.py +12 -8
  221. torch_geometric/utils/geodesic.py +24 -22
  222. torch_geometric/utils/hetero.py +1 -1
  223. torch_geometric/utils/map.py +8 -2
  224. torch_geometric/utils/smiles.py +65 -27
  225. torch_geometric/utils/sparse.py +39 -25
  226. torch_geometric/visualization/graph.py +3 -4
torch_geometric/nn/sequential.py

@@ -1,25 +1,33 @@
+import copy
+import inspect
 import os.path as osp
 import random
-from typing import Callable, List, NamedTuple, Optional, Tuple, Union
+import sys
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    NamedTuple,
+    Optional,
+    Tuple,
+    Union,
+)

 import torch
 from torch import Tensor

-from torch_geometric.inspector import split, type_repr
+from torch_geometric.inspector import Parameter, Signature, eval_type, split
 from torch_geometric.template import module_from_template


 class Child(NamedTuple):
     name: str
-    module: Callable
     param_names: List[str]
     return_names: List[str]


-def Sequential(
-    input_args: str,
-    modules: List[Union[Tuple[Callable, str], Callable]],
-) -> torch.nn.Module:
+class Sequential(torch.nn.Module):
     r"""An extension of the :class:`torch.nn.Sequential` container in order to
     define a sequential GNN model.

@@ -69,79 +77,197 @@ def Sequential(

     Args:
         input_args (str): The input arguments of the model.
-        modules ([(str, Callable) or Callable]): A list of modules (with
+        modules ([(Callable, str) or Callable]): A list of modules (with
             optional function header definitions). Alternatively, an
            :obj:`OrderedDict` of modules (and function header definitions) can
            be passed.
     """
-    signature = input_args.split('->')
-    if len(signature) == 1:
-        input_args = signature[0]
-        return_type = type_repr(Tensor, globals())
-    elif len(signature) == 2:
-        input_args, return_type = signature[0], signature[1].strip()
-    else:
-        raise ValueError(f"Failed to parse arguments (got '{input_args}')")
-
-    input_types = split(input_args, sep=',')
-    if len(input_types) == 0:
-        raise ValueError(f"Failed to parse arguments (got '{input_args}')")
-
-    if not isinstance(modules, dict):
-        modules = {f'module_{i}': module for i, module in enumerate(modules)}
-    if len(modules) == 0:
-        raise ValueError("'Sequential' expected a non-empty list of modules")
-
-    children: List[Child] = []
-    for i, (name, module) in enumerate(modules.items()):
-        desc: Optional[str] = None
-        if isinstance(module, (tuple, list)):
-            if len(module) == 1:
-                module = module[0]
-            elif len(module) == 2:
-                module, desc = module
-            else:
-                raise ValueError(f"Expected tuple of length 2 (got {module})")
-
-        if i == 0 and desc is None:
-            raise ValueError("Requires signature for first module")
-        if not callable(module):
-            raise ValueError(f"Expected callable module (got {module})")
-        if desc is not None and not isinstance(desc, str):
-            raise ValueError(f"Expected type hint representation (got {desc})")
-
-        if desc is not None:
-            signature = desc.split('->')
-            if len(signature) != 2:
-                raise ValueError(f"Failed to parse arguments (got '{desc}')")
-            param_names = [v.strip() for v in signature[0].split(',')]
-            return_names = [v.strip() for v in signature[1].split(',')]
-            child = Child(name, module, param_names, return_names)
+    _children: List[Child]
+
+    def __init__(
+        self,
+        input_args: str,
+        modules: List[Union[Tuple[Callable, str], Callable]],
+    ) -> None:
+        super().__init__()
+
+        caller_path = inspect.stack()[1].filename
+        self._caller_module = osp.splitext(osp.basename(caller_path))[0]
+
+        _globals = copy.copy(globals())
+        _globals.update(sys.modules['__main__'].__dict__)
+        if self._caller_module in sys.modules:
+            _globals.update(sys.modules[self._caller_module].__dict__)
+
+        signature = input_args.split('->')
+        if len(signature) == 1:
+            args_repr = signature[0]
+            return_type_repr = 'Tensor'
+            return_type = Tensor
+        elif len(signature) == 2:
+            args_repr = signature[0]
+            return_type_repr = signature[1].strip()
+            return_type = eval_type(return_type_repr, _globals)
         else:
-            param_names = children[-1].return_names
-            child = Child(name, module, param_names, param_names)
-
-        children.append(child)
-
-    uid = '%06x' % random.randrange(16**6)
-    root_dir = osp.dirname(osp.realpath(__file__))
-    module = module_from_template(
-        module_name=f'torch_geometric.nn.sequential_{uid}',
-        template_path=osp.join(root_dir, 'sequential.jinja'),
-        tmp_dirname='sequential',
-        # Keyword arguments:
-        input_types=input_types,
-        return_type=return_type,
-        children=children,
-    )
-
-    model = module.Sequential()
-    model._module_names = [child.name for child in children]
-    model._module_descs = [
-        f"{', '.join(child.param_names)} -> {', '.join(child.return_names)}"
-        for child in children
-    ]
-    for child in children:
-        setattr(model, child.name, child.module)
-
-    return model
+            raise ValueError(f"Failed to parse arguments (got '{input_args}')")
+
+        param_dict: Dict[str, Parameter] = {}
+        for arg in split(args_repr, sep=','):
+            signature = arg.split(':')
+            if len(signature) == 1:
+                name = signature[0].strip()
+                param_dict[name] = Parameter(
+                    name=name,
+                    type=Tensor,
+                    type_repr='Tensor',
+                    default=inspect._empty,
+                )
+            elif len(signature) == 2:
+                name = signature[0].strip()
+                param_dict[name] = Parameter(
+                    name=name,
+                    type=eval_type(signature[1].strip(), _globals),
+                    type_repr=signature[1].strip(),
+                    default=inspect._empty,
+                )
+            else:
+                raise ValueError(f"Failed to parse argument "
+                                 f"(got '{arg.strip()}')")
+
+        self.signature = Signature(param_dict, return_type, return_type_repr)
+
+        if not isinstance(modules, dict):
+            modules = {
+                f'module_{i}': module
+                for i, module in enumerate(modules)
+            }
+        if len(modules) == 0:
+            raise ValueError(f"'{self.__class__.__name__}' expects a "
+                             f"non-empty list of modules")
+
+        self._children: List[Child] = []
+        for i, (name, module) in enumerate(modules.items()):
+            desc: Optional[str] = None
+            if isinstance(module, (tuple, list)):
+                if len(module) == 1:
+                    module = module[0]
+                elif len(module) == 2:
+                    module, desc = module
+                else:
+                    raise ValueError(f"Expected tuple of length 2 "
+                                     f"(got {module})")
+
+            if i == 0 and desc is None:
+                raise ValueError("Signature for first module required")
+            if not callable(module):
+                raise ValueError(f"Expected callable module (got {module})")
+            if desc is not None and not isinstance(desc, str):
+                raise ValueError(f"Expected type hint representation "
+                                 f"(got {desc})")
+
+            if desc is not None:
+                signature = desc.split('->')
+                if len(signature) != 2:
+                    raise ValueError(
+                        f"Failed to parse arguments (got '{desc}')")
+                param_names = [v.strip() for v in signature[0].split(',')]
+                return_names = [v.strip() for v in signature[1].split(',')]
+                child = Child(name, param_names, return_names)
+            else:
+                param_names = self._children[-1].return_names
+                child = Child(name, param_names, param_names)
+
+            setattr(self, name, module)
+            self._children.append(child)
+
+        self._set_jittable_template()
+
+    def reset_parameters(self) -> None:
+        r"""Resets all learnable parameters of the module."""
+        for child in self._children:
+            module = getattr(self, child.name)
+            if hasattr(module, 'reset_parameters'):
+                module.reset_parameters()
+
+    def __len__(self) -> int:
+        return len(self._children)
+
+    def __getitem__(self, idx: int) -> torch.nn.Module:
+        return getattr(self, self._children[idx].name)
+
+    def __setstate__(self, data: Dict[str, Any]) -> None:
+        super().__setstate__(data)
+        self._set_jittable_template()
+
+    def __repr__(self) -> str:
+        module_descs = [
+            f"{', '.join(c.param_names)} -> {', '.join(c.return_names)}"
+            for c in self._children
+        ]
+        module_reprs = [
+            f'  ({i}) - {self[i]}: {module_descs[i]}' for i in range(len(self))
+        ]
+        return '{}(\n{}\n)'.format(
+            self.__class__.__name__,
+            '\n'.join(module_reprs),
+        )
+
+    def forward(self, *args: Any, **kwargs: Any) -> Any:
+        """"""  # noqa: D419
+        value_dict = {
+            name: arg
+            for name, arg in zip(self.signature.param_dict.keys(), args)
+        }
+        for key, arg in kwargs.items():
+            if key in value_dict:
+                raise TypeError(f"'{self.__class__.__name__}' got multiple "
+                                f"values for argument '{key}'")
+            value_dict[key] = arg
+
+        for child in self._children:
+            args = [value_dict[name] for name in child.param_names]
+            outs = getattr(self, child.name)(*args)
+            if len(child.return_names) == 1:
+                value_dict[child.return_names[0]] = outs
+            else:
+                for name, out in zip(child.return_names, outs):
+                    value_dict[name] = out
+
+        return outs
+
+    # TorchScript Support #####################################################
+
+    def _set_jittable_template(self, raise_on_error: bool = False) -> None:
+        try:  # Optimize `forward()` via `*.jinja` templates:
+            if ('forward' in self.__class__.__dict__ and
+                    self.__class__.__dict__['forward'] != Sequential.forward):
+                raise ValueError("Cannot compile custom 'forward' method")
+
+            root_dir = osp.dirname(osp.realpath(__file__))
+            uid = '%06x' % random.randrange(16**6)
+            jinja_prefix = f'{self.__module__}_{self.__class__.__name__}_{uid}'
+            module = module_from_template(
+                module_name=jinja_prefix,
+                template_path=osp.join(root_dir, 'sequential.jinja'),
+                tmp_dirname='sequential',
+                # Keyword arguments:
+                modules=[self._caller_module],
+                signature=self.signature,
+                children=self._children,
+            )
+
+            self.forward = module.forward.__get__(self)
+
+            # NOTE We override `forward` on the class level here in order to
+            # support `torch.jit.trace` - this is generally dangerous to do,
+            # and limits `torch.jit.trace` to a single `Sequential` module:
+            self.__class__.forward = module.forward
+        except Exception as e:  # pragma: no cover
+            if raise_on_error:
+                raise e
+
+    def __prepare_scriptable__(self) -> 'Sequential':
+        # Prevent type sharing when scripting `Sequential` modules:
+        type_store = torch.jit._recursive.concrete_type_store.type_store
+        type_store.pop(self.__class__, None)
+        return self
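
The hunks above replace the old `Sequential` factory function with a proper `torch.nn.Module` subclass that parses typed input signatures, supports `reset_parameters()`, `len()`, indexing, and pickling, and regenerates its jitted `forward()` template on demand. A minimal usage sketch, following the container's documented call convention (layer sizes and the toy graph are illustrative):

    import torch
    from torch_geometric.nn import GCNConv, Sequential

    # Each module is paired with an 'inputs -> outputs' header; modules
    # without a header reuse the previous entry's return names:
    model = Sequential('x, edge_index', [
        (GCNConv(16, 64), 'x, edge_index -> x'),
        torch.nn.ReLU(inplace=True),
        (GCNConv(64, 7), 'x, edge_index -> x'),
    ])

    x = torch.randn(10, 16)
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
    out = model(x, edge_index)  # Shape: [10, 7]

Since `Sequential` instances are now ordinary modules, `model[0]` returns the first child and `model.reset_parameters()` re-initializes all learnable parameters in place.
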
torch_geometric/nn/summary.py

@@ -141,7 +141,7 @@ def get_shape(inputs: Any) -> str:
 def postprocess(info_list: List[dict]) -> List[dict]:
     for idx, info in enumerate(info_list):
         depth = info['depth']
-        if idx > 0:  # root module (0) is exclued
+        if idx > 0:  # root module (0) is excluded
             if depth == 1:
                 prefix = '├─'
             else:
torch_geometric/nn/to_hetero_with_bases_transformer.py

@@ -272,7 +272,6 @@ class ToHeteroWithBasesTransformer(Transformer):
                 args=(value, self.find_by_name('edge_offset_dict')),
                 name=f'{value.name}__split')

-            pass
         elif isinstance(value, Node):
             self.graph.inserting_before(node)
             return self.graph.create_node(
@@ -309,6 +308,24 @@ class ToHeteroWithBasesTransformer(Transformer):
 ###############################################################################


+# We make use of a post-message computation hook to inject the
+# basis re-weighting for each individual edge type.
+# This currently requires us to set `conv.fuse = False`, which leads
+# to a materialization of messages.
+def hook(module, inputs, output):
+    assert isinstance(module._edge_type, Tensor)
+    if module._edge_type.size(0) != output.size(-2):
+        raise ValueError(
+            f"Number of messages ({output.size(0)}) does not match "
+            f"with the number of original edges "
+            f"({module._edge_type.size(0)}). Does your message "
+            f"passing layer create additional self-loops? Try to "
+            f"remove them via 'add_self_loops=False'")
+    weight = module.edge_type_weight.view(-1)[module._edge_type]
+    weight = weight.view([1] * (output.dim() - 2) + [-1, 1])
+    return weight * output
+
+
 class HeteroBasisConv(torch.nn.Module):
     # A wrapper layer that applies the basis-decomposition technique to a
     # heterogeneous graph.
@@ -319,23 +336,6 @@ class HeteroBasisConv(torch.nn.Module):
         self.num_relations = num_relations
         self.num_bases = num_bases

-        # We make use of a post-message computation hook to inject the
-        # basis re-weighting for each individual edge type.
-        # This currently requires us to set `conv.fuse = False`, which leads
-        # to a materialization of messages.
-        def hook(module, inputs, output):
-            assert isinstance(module._edge_type, Tensor)
-            if module._edge_type.size(0) != output.size(-2):
-                raise ValueError(
-                    f"Number of messages ({output.size(0)}) does not match "
-                    f"with the number of original edges "
-                    f"({module._edge_type.size(0)}). Does your message "
-                    f"passing layer create additional self-loops? Try to "
-                    f"remove them via 'add_self_loops=False'")
-            weight = module.edge_type_weight.view(-1)[module._edge_type]
-            weight = weight.view([1] * (output.dim() - 2) + [-1, 1])
-            return weight * output
-
         params = list(module.parameters())
         device = params[0].device if len(params) > 0 else 'cpu'

@@ -468,7 +468,7 @@ def get_edge_type(
 ###############################################################################

 # These methods are used to group the individual type-wise components into a
-# unfied single representation.
+# unified single representation.


 def group_node_placeholder(input_dict: Dict[NodeType, Tensor],
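
The two hunks above hoist the post-message re-weighting `hook` out of `HeteroBasisConv.__init__` to module level without changing its logic. As a hedged, standalone sketch of the same mechanism, the `MessagePassing` hook API can attach such a function to any conv layer (the constant re-weighting below is a made-up placeholder, not the basis-decomposition weighting used above):

    import torch
    from torch_geometric.nn import GraphConv

    conv = GraphConv(16, 32)
    conv.fuse = False  # materialize messages so the hook can observe them

    def scale_messages(module, inputs, output):
        # `output` holds one row per message; returning a new tensor
        # replaces the messages before aggregation:
        return 0.5 * output

    handle = conv.register_message_forward_hook(scale_messages)

    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
    out = conv(x, edge_index)
    handle.remove()
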
torch_geometric/profile/__init__.py

@@ -20,6 +20,7 @@ from .utils import (
     get_gpu_memory_from_nvidia_smi,
     get_model_size,
 )
+from .nvtx import nvtxit

 __all__ = [
     'profileit',
@@ -38,6 +39,7 @@ __all__ = [
     'get_gpu_memory_from_nvidia_smi',
     'get_gpu_memory_from_ipex',
     'benchmark',
+    'nvtxit',
 ]

 classes = __all__
torch_geometric/profile/nvtx.py (new file)

@@ -0,0 +1,66 @@
+from functools import wraps
+from typing import Optional
+
+import torch
+
+CUDA_PROFILE_STARTED = False
+
+
+def begin_cuda_profile():
+    global CUDA_PROFILE_STARTED
+    prev_state = CUDA_PROFILE_STARTED
+    if prev_state is False:
+        CUDA_PROFILE_STARTED = True
+        torch.cuda.cudart().cudaProfilerStart()
+    return prev_state
+
+
+def end_cuda_profile(prev_state: bool):
+    global CUDA_PROFILE_STARTED
+    CUDA_PROFILE_STARTED = prev_state
+    if prev_state is False:
+        torch.cuda.cudart().cudaProfilerStop()
+
+
+def nvtxit(name: Optional[str] = None, n_warmups: int = 0,
+           n_iters: Optional[int] = None):
+    """Enables NVTX profiling for a function.
+
+    Args:
+        name (Optional[str], optional): Name to give the reference frame for
+            the function being wrapped. Defaults to the name of the
+            function in code.
+        n_warmups (int, optional): Number of iters to call that function
+            before starting. Defaults to 0.
+        n_iters (Optional[int], optional): Number of iters of that function to
+            record. Defaults to all of them.
+    """
+    def nvtx(func):
+
+        nonlocal name
+        iters_so_far = 0
+        if name is None:
+            name = func.__name__
+
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            nonlocal iters_so_far
+            if not torch.cuda.is_available():
+                return func(*args, **kwargs)
+            elif iters_so_far < n_warmups:
+                iters_so_far += 1
+                return func(*args, **kwargs)
+            elif n_iters is None or iters_so_far < n_iters + n_warmups:
+                prev_state = begin_cuda_profile()
+                torch.cuda.nvtx.range_push(f"{name}_{iters_so_far}")
+                result = func(*args, **kwargs)
+                torch.cuda.nvtx.range_pop()
+                end_cuda_profile(prev_state)
+                iters_so_far += 1
+                return result
+            else:
+                return func(*args, **kwargs)
+
+        return wrapper
+
+    return nvtx
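
For the new decorator, each recorded call is wrapped in an NVTX range named `<name>_<iteration>`, and `cudaProfilerStart()`/`cudaProfilerStop()` delimit the recorded window. A short usage sketch (the function and argument values are illustrative):

    import torch
    from torch_geometric.profile import nvtxit

    # Skip two warm-up calls, then record the next ten as NVTX ranges:
    @nvtxit(name='dot', n_warmups=2, n_iters=10)
    def dot(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return a @ b

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    a = torch.randn(64, 64, device=device)
    for _ in range(20):
        dot(a, a)

Run the script under Nsight Systems with capture driven by the profiler API, e.g. `nsys profile --capture-range=cudaProfilerApi python script.py`, so that only the decorated iterations are captured.
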
torch_geometric/profile/profiler.py

@@ -5,6 +5,8 @@ from typing import Any, List, NamedTuple, Optional, Tuple
 import torch
 import torch.profiler as torch_profiler

+import torch_geometric.typing
+
 # predefined namedtuple for variable setting (global template)
 Trace = namedtuple('Trace', ['path', 'leaf', 'module'])

@@ -325,41 +327,50 @@ def _flatten_tree(t, depth=0):


 def _build_measure_tuple(events: List, occurrences: List) -> NamedTuple:
+    device_str = 'device' if torch_geometric.typing.WITH_PT24 else 'cuda'
+
     # memory profiling supported in torch >= 1.6
     self_cpu_memory = None
     has_self_cpu_memory = any(
         hasattr(e, "self_cpu_memory_usage") for e in events)
     if has_self_cpu_memory:
         self_cpu_memory = sum(
-            [getattr(e, "self_cpu_memory_usage", 0) for e in events])
+            [getattr(e, "self_cpu_memory_usage", 0) or 0 for e in events])
     cpu_memory = None
     has_cpu_memory = any(hasattr(e, "cpu_memory_usage") for e in events)
     if has_cpu_memory:
-        cpu_memory = sum([getattr(e, "cpu_memory_usage", 0) for e in events])
+        cpu_memory = sum(
+            [getattr(e, "cpu_memory_usage", 0) or 0 for e in events])
     self_cuda_memory = None
     has_self_cuda_memory = any(
-        hasattr(e, "self_cuda_memory_usage") for e in events)
+        hasattr(e, f"self_{device_str}_memory_usage") for e in events)
     if has_self_cuda_memory:
-        self_cuda_memory = sum(
-            [getattr(e, "self_cuda_memory_usage", 0) for e in events])
+        self_cuda_memory = sum([
+            getattr(e, f"self_{device_str}_memory_usage", 0) or 0
+            for e in events
+        ])
     cuda_memory = None
-    has_cuda_memory = any(hasattr(e, "cuda_memory_usage") for e in events)
+    has_cuda_memory = any(
+        hasattr(e, f"{device_str}_memory_usage") for e in events)
     if has_cuda_memory:
-        cuda_memory = sum([getattr(e, "cuda_memory_usage", 0) for e in events])
+        cuda_memory = sum(
+            [getattr(e, f"{device_str}_memory_usage", 0) or 0 for e in events])

     # self CUDA time supported in torch >= 1.7
     self_cuda_total = None
     has_self_cuda_time = any(
-        hasattr(e, "self_cuda_time_total") for e in events)
+        hasattr(e, f"self_{device_str}_time_total") for e in events)
     if has_self_cuda_time:
-        self_cuda_total = sum(
-            [getattr(e, "self_cuda_time_total", 0) for e in events])
+        self_cuda_total = sum([
+            getattr(e, f"self_{device_str}_time_total", 0) or 0 for e in events
+        ])

     return Measure(
-        self_cpu_total=sum([e.self_cpu_time_total for e in events]),
-        cpu_total=sum([e.cpu_time_total for e in events]),
+        self_cpu_total=sum([e.self_cpu_time_total or 0 for e in events]),
+        cpu_total=sum([e.cpu_time_total or 0 for e in events]),
         self_cuda_total=self_cuda_total,
-        cuda_total=sum([e.cuda_time_total for e in events]),
+        cuda_total=sum(
+            [getattr(e, f"{device_str}_time_total") or 0 for e in events]),
         self_cpu_memory=self_cpu_memory,
         cpu_memory=cpu_memory,
         self_cuda_memory=self_cuda_memory,
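
The recurring `or 0` guards added in this hunk exist because profiler events may report `None` for a memory or time field, which makes a bare `sum()` raise a `TypeError`. A minimal sketch of the failure mode and the fix:

    from types import SimpleNamespace

    events = [SimpleNamespace(cpu_memory_usage=None),
              SimpleNamespace(cpu_memory_usage=128)]

    # sum(e.cpu_memory_usage for e in events)  # TypeError: int + NoneType
    total = sum(getattr(e, 'cpu_memory_usage', 0) or 0 for e in events)
    assert total == 128

The `device_str` switch follows the same pattern: PyTorch 2.4 renamed the profiler's `*_cuda_*` event fields to `*_device_*`, so the attribute names are assembled per installed version.
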
@@ -434,10 +445,10 @@ def format_time(time_us: int) -> str:
     US_IN_SECOND = 1000.0 * 1000.0
     US_IN_MS = 1000.0
     if time_us >= US_IN_SECOND:
-        return '{:.3f}s'.format(time_us / US_IN_SECOND)
+        return f'{time_us / US_IN_SECOND:.3f}s'
     if time_us >= US_IN_MS:
-        return '{:.3f}ms'.format(time_us / US_IN_MS)
-    return '{:.3f}us'.format(time_us)
+        return f'{time_us / US_IN_MS:.3f}ms'
+    return f'{time_us:.3f}us'


 def format_memory(nbytes: int) -> str:
@@ -446,10 +457,10 @@ def format_memory(nbytes: int) -> str:
     MB = 1024 * KB
     GB = 1024 * MB
     if (abs(nbytes) >= GB):
-        return '{:.2f} Gb'.format(nbytes * 1.0 / GB)
+        return f'{nbytes * 1.0 / GB:.2f} Gb'
     elif (abs(nbytes) >= MB):
-        return '{:.2f} Mb'.format(nbytes * 1.0 / MB)
+        return f'{nbytes * 1.0 / MB:.2f} Mb'
     elif (abs(nbytes) >= KB):
-        return '{:.2f} Kb'.format(nbytes * 1.0 / KB)
+        return f'{nbytes * 1.0 / KB:.2f} Kb'
     else:
         return str(nbytes) + ' b'
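
The `str.format` calls here are converted to equivalent f-strings, with no behavior change. A quick check of the helpers (values are illustrative; both functions are module-level helpers in torch_geometric/profile/profiler.py):

    from torch_geometric.profile.profiler import format_memory, format_time

    assert format_time(1_500_000) == '1.500s'   # microseconds -> seconds
    assert format_time(2_500) == '2.500ms'
    assert format_time(999) == '999.000us'
    assert format_memory(3 * 1024 * 1024) == '3.00 Mb'
    assert format_memory(512) == '512 b'
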
torch_geometric/resolver.py

@@ -39,5 +39,5 @@ def resolver(
                 return obj
             return cls

-    choices = set(cls.__name__ for cls in classes) | set(class_dict.keys())
+    choices = {cls.__name__ for cls in classes} | set(class_dict.keys())
     raise ValueError(f"Could not resolve '{query}' among choices {choices}")