pyg-nightly 2.6.0.dev20240318__py3-none-any.whl → 2.7.0.dev20250115__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (226) hide show
  1. {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/METADATA +31 -47
  2. {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/RECORD +226 -199
  3. {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/WHEEL +1 -1
  4. torch_geometric/__init__.py +28 -1
  5. torch_geometric/_compile.py +8 -1
  6. torch_geometric/_onnx.py +14 -0
  7. torch_geometric/config_mixin.py +113 -0
  8. torch_geometric/config_store.py +28 -19
  9. torch_geometric/data/__init__.py +24 -1
  10. torch_geometric/data/batch.py +2 -2
  11. torch_geometric/data/collate.py +8 -2
  12. torch_geometric/data/data.py +16 -8
  13. torch_geometric/data/database.py +61 -15
  14. torch_geometric/data/dataset.py +14 -6
  15. torch_geometric/data/feature_store.py +25 -42
  16. torch_geometric/data/graph_store.py +1 -5
  17. torch_geometric/data/hetero_data.py +18 -9
  18. torch_geometric/data/in_memory_dataset.py +2 -4
  19. torch_geometric/data/large_graph_indexer.py +677 -0
  20. torch_geometric/data/lightning/datamodule.py +4 -4
  21. torch_geometric/data/separate.py +6 -1
  22. torch_geometric/data/storage.py +17 -7
  23. torch_geometric/data/summary.py +14 -4
  24. torch_geometric/data/temporal.py +1 -2
  25. torch_geometric/datasets/__init__.py +17 -2
  26. torch_geometric/datasets/actor.py +9 -11
  27. torch_geometric/datasets/airfrans.py +15 -18
  28. torch_geometric/datasets/airports.py +10 -12
  29. torch_geometric/datasets/amazon.py +8 -11
  30. torch_geometric/datasets/amazon_book.py +9 -10
  31. torch_geometric/datasets/amazon_products.py +9 -10
  32. torch_geometric/datasets/aminer.py +8 -9
  33. torch_geometric/datasets/aqsol.py +10 -13
  34. torch_geometric/datasets/attributed_graph_dataset.py +10 -12
  35. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  36. torch_geometric/datasets/ba_shapes.py +5 -6
  37. torch_geometric/datasets/bitcoin_otc.py +1 -1
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/cornell.py +145 -0
  40. torch_geometric/datasets/dblp.py +2 -1
  41. torch_geometric/datasets/dbp15k.py +2 -2
  42. torch_geometric/datasets/fake.py +1 -3
  43. torch_geometric/datasets/flickr.py +2 -1
  44. torch_geometric/datasets/freebase.py +1 -1
  45. torch_geometric/datasets/gdelt_lite.py +3 -2
  46. torch_geometric/datasets/ged_dataset.py +3 -2
  47. torch_geometric/datasets/git_mol_dataset.py +263 -0
  48. torch_geometric/datasets/gnn_benchmark_dataset.py +11 -10
  49. torch_geometric/datasets/hgb_dataset.py +8 -8
  50. torch_geometric/datasets/imdb.py +2 -1
  51. torch_geometric/datasets/karate.py +3 -2
  52. torch_geometric/datasets/last_fm.py +2 -1
  53. torch_geometric/datasets/linkx_dataset.py +4 -3
  54. torch_geometric/datasets/lrgb.py +3 -5
  55. torch_geometric/datasets/malnet_tiny.py +4 -3
  56. torch_geometric/datasets/mnist_superpixels.py +2 -3
  57. torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
  58. torch_geometric/datasets/molecule_net.py +15 -3
  59. torch_geometric/datasets/motif_generator/base.py +0 -1
  60. torch_geometric/datasets/neurograph.py +1 -3
  61. torch_geometric/datasets/ogb_mag.py +1 -1
  62. torch_geometric/datasets/opf.py +239 -0
  63. torch_geometric/datasets/ose_gvcs.py +1 -1
  64. torch_geometric/datasets/pascal.py +11 -9
  65. torch_geometric/datasets/pascal_pf.py +1 -1
  66. torch_geometric/datasets/pcpnet_dataset.py +1 -1
  67. torch_geometric/datasets/pcqm4m.py +10 -3
  68. torch_geometric/datasets/ppi.py +1 -1
  69. torch_geometric/datasets/qm9.py +8 -7
  70. torch_geometric/datasets/rcdd.py +4 -4
  71. torch_geometric/datasets/reddit.py +2 -1
  72. torch_geometric/datasets/reddit2.py +2 -1
  73. torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
  74. torch_geometric/datasets/s3dis.py +5 -3
  75. torch_geometric/datasets/shapenet.py +3 -3
  76. torch_geometric/datasets/shrec2016.py +2 -2
  77. torch_geometric/datasets/snap_dataset.py +7 -1
  78. torch_geometric/datasets/tag_dataset.py +350 -0
  79. torch_geometric/datasets/upfd.py +2 -1
  80. torch_geometric/datasets/web_qsp_dataset.py +246 -0
  81. torch_geometric/datasets/webkb.py +2 -2
  82. torch_geometric/datasets/wikics.py +1 -1
  83. torch_geometric/datasets/wikidata.py +3 -2
  84. torch_geometric/datasets/wikipedia_network.py +2 -2
  85. torch_geometric/datasets/willow_object_class.py +1 -1
  86. torch_geometric/datasets/word_net.py +2 -2
  87. torch_geometric/datasets/yelp.py +2 -1
  88. torch_geometric/datasets/zinc.py +1 -1
  89. torch_geometric/device.py +42 -0
  90. torch_geometric/distributed/local_feature_store.py +3 -2
  91. torch_geometric/distributed/local_graph_store.py +2 -1
  92. torch_geometric/distributed/partition.py +9 -8
  93. torch_geometric/edge_index.py +616 -438
  94. torch_geometric/explain/algorithm/base.py +0 -1
  95. torch_geometric/explain/algorithm/graphmask_explainer.py +1 -2
  96. torch_geometric/explain/algorithm/pg_explainer.py +1 -1
  97. torch_geometric/explain/explanation.py +2 -2
  98. torch_geometric/graphgym/checkpoint.py +2 -1
  99. torch_geometric/graphgym/logger.py +4 -4
  100. torch_geometric/graphgym/loss.py +1 -1
  101. torch_geometric/graphgym/utils/agg_runs.py +6 -6
  102. torch_geometric/index.py +826 -0
  103. torch_geometric/inspector.py +13 -7
  104. torch_geometric/io/fs.py +28 -2
  105. torch_geometric/io/npz.py +2 -1
  106. torch_geometric/io/off.py +2 -2
  107. torch_geometric/io/sdf.py +2 -2
  108. torch_geometric/io/tu.py +4 -5
  109. torch_geometric/loader/__init__.py +4 -0
  110. torch_geometric/loader/cluster.py +10 -4
  111. torch_geometric/loader/graph_saint.py +2 -1
  112. torch_geometric/loader/ibmb_loader.py +12 -4
  113. torch_geometric/loader/mixin.py +1 -1
  114. torch_geometric/loader/neighbor_loader.py +1 -1
  115. torch_geometric/loader/neighbor_sampler.py +2 -2
  116. torch_geometric/loader/prefetch.py +1 -1
  117. torch_geometric/loader/rag_loader.py +107 -0
  118. torch_geometric/loader/utils.py +8 -7
  119. torch_geometric/loader/zip_loader.py +10 -0
  120. torch_geometric/metrics/__init__.py +11 -2
  121. torch_geometric/metrics/link_pred.py +317 -65
  122. torch_geometric/nn/aggr/__init__.py +4 -0
  123. torch_geometric/nn/aggr/attention.py +0 -2
  124. torch_geometric/nn/aggr/base.py +3 -5
  125. torch_geometric/nn/aggr/patch_transformer.py +143 -0
  126. torch_geometric/nn/aggr/set_transformer.py +1 -1
  127. torch_geometric/nn/aggr/variance_preserving.py +33 -0
  128. torch_geometric/nn/attention/__init__.py +5 -1
  129. torch_geometric/nn/attention/qformer.py +71 -0
  130. torch_geometric/nn/conv/collect.jinja +7 -4
  131. torch_geometric/nn/conv/cugraph/base.py +8 -12
  132. torch_geometric/nn/conv/edge_conv.py +3 -2
  133. torch_geometric/nn/conv/fused_gat_conv.py +1 -1
  134. torch_geometric/nn/conv/gat_conv.py +35 -7
  135. torch_geometric/nn/conv/gatv2_conv.py +36 -6
  136. torch_geometric/nn/conv/general_conv.py +1 -1
  137. torch_geometric/nn/conv/graph_conv.py +21 -3
  138. torch_geometric/nn/conv/gravnet_conv.py +3 -2
  139. torch_geometric/nn/conv/hetero_conv.py +3 -3
  140. torch_geometric/nn/conv/hgt_conv.py +1 -1
  141. torch_geometric/nn/conv/message_passing.py +138 -87
  142. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  143. torch_geometric/nn/conv/propagate.jinja +9 -1
  144. torch_geometric/nn/conv/rgcn_conv.py +5 -5
  145. torch_geometric/nn/conv/spline_conv.py +4 -4
  146. torch_geometric/nn/conv/x_conv.py +3 -2
  147. torch_geometric/nn/dense/linear.py +11 -6
  148. torch_geometric/nn/fx.py +3 -3
  149. torch_geometric/nn/model_hub.py +3 -1
  150. torch_geometric/nn/models/__init__.py +10 -2
  151. torch_geometric/nn/models/deep_graph_infomax.py +1 -2
  152. torch_geometric/nn/models/dimenet_utils.py +5 -7
  153. torch_geometric/nn/models/g_retriever.py +230 -0
  154. torch_geometric/nn/models/git_mol.py +336 -0
  155. torch_geometric/nn/models/glem.py +385 -0
  156. torch_geometric/nn/models/gnnff.py +0 -1
  157. torch_geometric/nn/models/graph_unet.py +12 -3
  158. torch_geometric/nn/models/jumping_knowledge.py +63 -4
  159. torch_geometric/nn/models/lightgcn.py +1 -1
  160. torch_geometric/nn/models/metapath2vec.py +5 -5
  161. torch_geometric/nn/models/molecule_gpt.py +222 -0
  162. torch_geometric/nn/models/node2vec.py +2 -3
  163. torch_geometric/nn/models/schnet.py +2 -1
  164. torch_geometric/nn/models/signed_gcn.py +3 -3
  165. torch_geometric/nn/module_dict.py +2 -2
  166. torch_geometric/nn/nlp/__init__.py +9 -0
  167. torch_geometric/nn/nlp/llm.py +329 -0
  168. torch_geometric/nn/nlp/sentence_transformer.py +134 -0
  169. torch_geometric/nn/nlp/vision_transformer.py +33 -0
  170. torch_geometric/nn/norm/batch_norm.py +1 -1
  171. torch_geometric/nn/parameter_dict.py +2 -2
  172. torch_geometric/nn/pool/__init__.py +21 -5
  173. torch_geometric/nn/pool/cluster_pool.py +145 -0
  174. torch_geometric/nn/pool/connect/base.py +0 -1
  175. torch_geometric/nn/pool/edge_pool.py +1 -1
  176. torch_geometric/nn/pool/graclus.py +4 -2
  177. torch_geometric/nn/pool/pool.py +8 -2
  178. torch_geometric/nn/pool/select/base.py +0 -1
  179. torch_geometric/nn/pool/voxel_grid.py +3 -2
  180. torch_geometric/nn/resolver.py +1 -1
  181. torch_geometric/nn/sequential.jinja +10 -23
  182. torch_geometric/nn/sequential.py +204 -78
  183. torch_geometric/nn/summary.py +1 -1
  184. torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
  185. torch_geometric/profile/__init__.py +2 -0
  186. torch_geometric/profile/nvtx.py +66 -0
  187. torch_geometric/profile/profiler.py +30 -19
  188. torch_geometric/resolver.py +1 -1
  189. torch_geometric/sampler/base.py +34 -13
  190. torch_geometric/sampler/neighbor_sampler.py +11 -10
  191. torch_geometric/sampler/utils.py +1 -1
  192. torch_geometric/template.py +1 -0
  193. torch_geometric/testing/__init__.py +6 -2
  194. torch_geometric/testing/decorators.py +56 -22
  195. torch_geometric/testing/feature_store.py +1 -1
  196. torch_geometric/transforms/__init__.py +2 -0
  197. torch_geometric/transforms/add_metapaths.py +5 -5
  198. torch_geometric/transforms/add_positional_encoding.py +1 -1
  199. torch_geometric/transforms/delaunay.py +65 -14
  200. torch_geometric/transforms/face_to_edge.py +32 -3
  201. torch_geometric/transforms/gdc.py +7 -6
  202. torch_geometric/transforms/laplacian_lambda_max.py +3 -3
  203. torch_geometric/transforms/mask.py +5 -1
  204. torch_geometric/transforms/node_property_split.py +1 -2
  205. torch_geometric/transforms/pad.py +7 -6
  206. torch_geometric/transforms/random_link_split.py +1 -1
  207. torch_geometric/transforms/remove_self_loops.py +36 -0
  208. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  209. torch_geometric/transforms/to_sparse_tensor.py +1 -1
  210. torch_geometric/transforms/two_hop.py +1 -1
  211. torch_geometric/transforms/virtual_node.py +2 -1
  212. torch_geometric/typing.py +43 -6
  213. torch_geometric/utils/__init__.py +5 -1
  214. torch_geometric/utils/_negative_sampling.py +1 -1
  215. torch_geometric/utils/_normalize_edge_index.py +46 -0
  216. torch_geometric/utils/_scatter.py +38 -12
  217. torch_geometric/utils/_subgraph.py +4 -0
  218. torch_geometric/utils/_tree_decomposition.py +2 -2
  219. torch_geometric/utils/augmentation.py +1 -1
  220. torch_geometric/utils/convert.py +12 -8
  221. torch_geometric/utils/geodesic.py +24 -22
  222. torch_geometric/utils/hetero.py +1 -1
  223. torch_geometric/utils/map.py +8 -2
  224. torch_geometric/utils/smiles.py +65 -27
  225. torch_geometric/utils/sparse.py +39 -25
  226. torch_geometric/visualization/graph.py +3 -4
@@ -50,7 +50,6 @@ class ExplainerAlgorithm(torch.nn.Module):
50
50
  r"""Checks if the explainer supports the user-defined settings provided
51
51
  in :obj:`self.explainer_config`, :obj:`self.model_config`.
52
52
  """
53
- pass
54
53
 
55
54
  ###########################################################################
56
55
 
@@ -1,7 +1,6 @@
1
1
  import math
2
2
  from typing import List, Optional, Tuple, Union
3
3
 
4
- import numpy as np
5
4
  import torch
6
5
  import torch.nn.functional as F
7
6
  from torch import Tensor
@@ -162,7 +161,7 @@ class GraphMaskExplainer(ExplainerAlgorithm):
162
161
  (torch.log(u) - torch.log(1 - u) + input_element) / beta)
163
162
 
164
163
  penalty = torch.sigmoid(input_element -
165
- beta * np.math.log(-gamma / zeta))
164
+ beta * math.log(-gamma / zeta))
166
165
  else:
167
166
  s = torch.sigmoid(input_element)
168
167
  penalty = torch.zeros_like(input_element)
@@ -59,7 +59,7 @@ class PGExplainer(ExplainerAlgorithm):
59
59
  'edge_size': 0.05,
60
60
  'edge_ent': 1.0,
61
61
  'temp': [5.0, 2.0],
62
- 'bias': 0.0,
62
+ 'bias': 0.01,
63
63
  }
64
64
 
65
65
  def __init__(self, epochs: int, lr: float = 0.003, **kwargs):
@@ -340,10 +340,10 @@ class HeteroExplanation(HeteroData, ExplanationMixin):
340
340
  """
341
341
  node_mask_dict = self.node_mask_dict
342
342
  for node_mask in node_mask_dict.values():
343
- if node_mask.dim() != 2 or node_mask.size(1) <= 1:
343
+ if node_mask.dim() != 2:
344
344
  raise ValueError(f"Cannot compute feature importance for "
345
345
  f"object-level 'node_mask' "
346
- f"(got shape {node_mask_dict.size()})")
346
+ f"(got shape {node_mask.size()})")
347
347
 
348
348
  if feat_labels is None:
349
349
  feat_labels = {}
@@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional, Union
6
6
  import torch
7
7
 
8
8
  from torch_geometric.graphgym.config import cfg
9
+ from torch_geometric.io import fs
9
10
 
10
11
  MODEL_STATE = 'model_state'
11
12
  OPTIMIZER_STATE = 'optimizer_state'
@@ -25,7 +26,7 @@ def load_ckpt(
25
26
  if not osp.exists(path):
26
27
  return 0
27
28
 
28
- ckpt = torch.load(path)
29
+ ckpt = fs.torch_load(path)
29
30
  model.load_state_dict(ckpt[MODEL_STATE])
30
31
  if optimizer is not None and OPTIMIZER_STATE in ckpt:
31
32
  optimizer.load_state_dict(ckpt[OPTIMIZER_STATE])
@@ -19,7 +19,7 @@ def set_printing():
19
19
  logging.root.handlers = []
20
20
  logging_cfg = {'level': logging.INFO, 'format': '%(message)s'}
21
21
  os.makedirs(cfg.run_dir, exist_ok=True)
22
- h_file = logging.FileHandler('{}/logging.log'.format(cfg.run_dir))
22
+ h_file = logging.FileHandler(f'{cfg.run_dir}/logging.log')
23
23
  h_stdout = logging.StreamHandler(sys.stdout)
24
24
  if cfg.print == 'file':
25
25
  logging_cfg['handlers'] = [h_file]
@@ -40,7 +40,7 @@ class Logger:
40
40
  self._epoch_total = cfg.optim.max_epoch
41
41
  self._time_total = 0 # won't be reset
42
42
 
43
- self.out_dir = '{}/{}'.format(cfg.run_dir, name)
43
+ self.out_dir = f'{cfg.run_dir}/{name}'
44
44
  os.makedirs(self.out_dir, exist_ok=True)
45
45
  if cfg.tensorboard_each_run:
46
46
  from tensorboardX import SummaryWriter
@@ -210,9 +210,9 @@ class Logger:
210
210
  }
211
211
 
212
212
  # print
213
- logging.info('{}: {}'.format(self.name, stats))
213
+ logging.info(f'{self.name}: {stats}')
214
214
  # json
215
- dict_to_json(stats, '{}/stats.json'.format(self.out_dir))
215
+ dict_to_json(stats, f'{self.out_dir}/stats.json')
216
216
  # tensorboard
217
217
  if cfg.tensorboard_each_run:
218
218
  dict_to_tb(stats, self.tb_writer, cur_epoch)
@@ -10,7 +10,7 @@ def compute_loss(pred, true):
10
10
 
11
11
  Args:
12
12
  pred (torch.tensor): Unnormalized prediction
13
- true (torch.tensor): Grou
13
+ true (torch.tensor): Ground truth labels
14
14
 
15
15
  Returns: Loss, normalized prediction score
16
16
 
@@ -54,7 +54,7 @@ def agg_dict_list(dict_list):
54
54
  if key != 'epoch':
55
55
  value = np.array([dict[key] for dict in dict_list])
56
56
  dict_agg[key] = np.mean(value).round(cfg.round)
57
- dict_agg['{}_std'.format(key)] = np.std(value).round(cfg.round)
57
+ dict_agg[f'{key}_std'] = np.std(value).round(cfg.round)
58
58
  return dict_agg
59
59
 
60
60
 
@@ -107,7 +107,7 @@ def agg_runs(dir, metric_best='auto'):
107
107
  [stats[metric] for stats in stats_list])
108
108
  best_epoch = \
109
109
  stats_list[
110
- eval("performance_np.{}()".format(cfg.metric_agg))][
110
+ eval(f"performance_np.{cfg.metric_agg}()")][
111
111
  'epoch']
112
112
  print(best_epoch)
113
113
 
@@ -190,7 +190,7 @@ def agg_batch(dir, metric_best='auto'):
190
190
  results[key] = pd.DataFrame(results[key])
191
191
  results[key] = results[key].sort_values(
192
192
  list(dict_name.keys()), ascending=[True] * len(dict_name))
193
- fname = osp.join(dir_out, '{}_best.csv'.format(key))
193
+ fname = osp.join(dir_out, f'{key}_best.csv')
194
194
  results[key].to_csv(fname, index=False)
195
195
 
196
196
  results = {'train': [], 'val': [], 'test': []}
@@ -213,7 +213,7 @@ def agg_batch(dir, metric_best='auto'):
213
213
  results[key] = pd.DataFrame(results[key])
214
214
  results[key] = results[key].sort_values(
215
215
  list(dict_name.keys()), ascending=[True] * len(dict_name))
216
- fname = osp.join(dir_out, '{}.csv'.format(key))
216
+ fname = osp.join(dir_out, f'{key}.csv')
217
217
  results[key].to_csv(fname, index=False)
218
218
 
219
219
  results = {'train': [], 'val': [], 'test': []}
@@ -245,7 +245,7 @@ def agg_batch(dir, metric_best='auto'):
245
245
  results[key] = pd.DataFrame(results[key])
246
246
  results[key] = results[key].sort_values(
247
247
  list(dict_name.keys()), ascending=[True] * len(dict_name))
248
- fname = osp.join(dir_out, '{}_bestepoch.csv'.format(key))
248
+ fname = osp.join(dir_out, f'{key}_bestepoch.csv')
249
249
  results[key].to_csv(fname, index=False)
250
250
 
251
- print('Results aggregated across models saved in {}'.format(dir_out))
251
+ print(f'Results aggregated across models saved in {dir_out}')