pyg-nightly 2.6.0.dev20240319__py3-none-any.whl → 2.7.0.dev20250114__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/METADATA +31 -47
- {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/RECORD +226 -199
- {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/WHEEL +1 -1
- torch_geometric/__init__.py +28 -1
- torch_geometric/_compile.py +8 -1
- torch_geometric/_onnx.py +14 -0
- torch_geometric/config_mixin.py +113 -0
- torch_geometric/config_store.py +28 -19
- torch_geometric/data/__init__.py +24 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +8 -2
- torch_geometric/data/data.py +16 -8
- torch_geometric/data/database.py +61 -15
- torch_geometric/data/dataset.py +14 -6
- torch_geometric/data/feature_store.py +25 -42
- torch_geometric/data/graph_store.py +1 -5
- torch_geometric/data/hetero_data.py +18 -9
- torch_geometric/data/in_memory_dataset.py +2 -4
- torch_geometric/data/large_graph_indexer.py +677 -0
- torch_geometric/data/lightning/datamodule.py +4 -4
- torch_geometric/data/separate.py +6 -1
- torch_geometric/data/storage.py +17 -7
- torch_geometric/data/summary.py +14 -4
- torch_geometric/data/temporal.py +1 -2
- torch_geometric/datasets/__init__.py +17 -2
- torch_geometric/datasets/actor.py +9 -11
- torch_geometric/datasets/airfrans.py +15 -18
- torch_geometric/datasets/airports.py +10 -12
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +9 -10
- torch_geometric/datasets/amazon_products.py +9 -10
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +10 -12
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/bitcoin_otc.py +1 -1
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/cornell.py +145 -0
- torch_geometric/datasets/dblp.py +2 -1
- torch_geometric/datasets/dbp15k.py +2 -2
- torch_geometric/datasets/fake.py +1 -3
- torch_geometric/datasets/flickr.py +2 -1
- torch_geometric/datasets/freebase.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +11 -10
- torch_geometric/datasets/hgb_dataset.py +8 -8
- torch_geometric/datasets/imdb.py +2 -1
- torch_geometric/datasets/karate.py +3 -2
- torch_geometric/datasets/last_fm.py +2 -1
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +4 -3
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
- torch_geometric/datasets/molecule_net.py +15 -3
- torch_geometric/datasets/motif_generator/base.py +0 -1
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +239 -0
- torch_geometric/datasets/ose_gvcs.py +1 -1
- torch_geometric/datasets/pascal.py +11 -9
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcpnet_dataset.py +1 -1
- torch_geometric/datasets/pcqm4m.py +10 -3
- torch_geometric/datasets/ppi.py +1 -1
- torch_geometric/datasets/qm9.py +8 -7
- torch_geometric/datasets/rcdd.py +4 -4
- torch_geometric/datasets/reddit.py +2 -1
- torch_geometric/datasets/reddit2.py +2 -1
- torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
- torch_geometric/datasets/s3dis.py +5 -3
- torch_geometric/datasets/shapenet.py +3 -3
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/snap_dataset.py +7 -1
- torch_geometric/datasets/tag_dataset.py +350 -0
- torch_geometric/datasets/upfd.py +2 -1
- torch_geometric/datasets/web_qsp_dataset.py +246 -0
- torch_geometric/datasets/webkb.py +2 -2
- torch_geometric/datasets/wikics.py +1 -1
- torch_geometric/datasets/wikidata.py +3 -2
- torch_geometric/datasets/wikipedia_network.py +2 -2
- torch_geometric/datasets/willow_object_class.py +1 -1
- torch_geometric/datasets/word_net.py +2 -2
- torch_geometric/datasets/yelp.py +2 -1
- torch_geometric/datasets/zinc.py +1 -1
- torch_geometric/device.py +42 -0
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/edge_index.py +616 -438
- torch_geometric/explain/algorithm/base.py +0 -1
- torch_geometric/explain/algorithm/graphmask_explainer.py +1 -2
- torch_geometric/explain/algorithm/pg_explainer.py +1 -1
- torch_geometric/explain/explanation.py +2 -2
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/logger.py +4 -4
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/utils/agg_runs.py +6 -6
- torch_geometric/index.py +826 -0
- torch_geometric/inspector.py +8 -3
- torch_geometric/io/fs.py +28 -2
- torch_geometric/io/npz.py +2 -1
- torch_geometric/io/off.py +2 -2
- torch_geometric/io/sdf.py +2 -2
- torch_geometric/io/tu.py +4 -5
- torch_geometric/loader/__init__.py +4 -0
- torch_geometric/loader/cluster.py +10 -4
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +12 -4
- torch_geometric/loader/mixin.py +1 -1
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +1 -1
- torch_geometric/loader/rag_loader.py +107 -0
- torch_geometric/loader/utils.py +8 -7
- torch_geometric/loader/zip_loader.py +10 -0
- torch_geometric/metrics/__init__.py +11 -2
- torch_geometric/metrics/link_pred.py +159 -34
- torch_geometric/nn/aggr/__init__.py +4 -0
- torch_geometric/nn/aggr/attention.py +0 -2
- torch_geometric/nn/aggr/base.py +2 -4
- torch_geometric/nn/aggr/patch_transformer.py +143 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/variance_preserving.py +33 -0
- torch_geometric/nn/attention/__init__.py +5 -1
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/conv/collect.jinja +7 -4
- torch_geometric/nn/conv/cugraph/base.py +8 -12
- torch_geometric/nn/conv/edge_conv.py +3 -2
- torch_geometric/nn/conv/fused_gat_conv.py +1 -1
- torch_geometric/nn/conv/gat_conv.py +35 -7
- torch_geometric/nn/conv/gatv2_conv.py +36 -6
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/graph_conv.py +21 -3
- torch_geometric/nn/conv/gravnet_conv.py +3 -2
- torch_geometric/nn/conv/hetero_conv.py +3 -3
- torch_geometric/nn/conv/hgt_conv.py +1 -1
- torch_geometric/nn/conv/message_passing.py +138 -87
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/propagate.jinja +9 -1
- torch_geometric/nn/conv/rgcn_conv.py +5 -5
- torch_geometric/nn/conv/spline_conv.py +4 -4
- torch_geometric/nn/conv/x_conv.py +3 -2
- torch_geometric/nn/dense/linear.py +11 -6
- torch_geometric/nn/fx.py +3 -3
- torch_geometric/nn/model_hub.py +3 -1
- torch_geometric/nn/models/__init__.py +10 -2
- torch_geometric/nn/models/deep_graph_infomax.py +1 -2
- torch_geometric/nn/models/dimenet_utils.py +5 -7
- torch_geometric/nn/models/g_retriever.py +230 -0
- torch_geometric/nn/models/git_mol.py +336 -0
- torch_geometric/nn/models/glem.py +385 -0
- torch_geometric/nn/models/gnnff.py +0 -1
- torch_geometric/nn/models/graph_unet.py +12 -3
- torch_geometric/nn/models/jumping_knowledge.py +63 -4
- torch_geometric/nn/models/lightgcn.py +1 -1
- torch_geometric/nn/models/metapath2vec.py +5 -5
- torch_geometric/nn/models/molecule_gpt.py +222 -0
- torch_geometric/nn/models/node2vec.py +2 -3
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/signed_gcn.py +3 -3
- torch_geometric/nn/module_dict.py +2 -2
- torch_geometric/nn/nlp/__init__.py +9 -0
- torch_geometric/nn/nlp/llm.py +322 -0
- torch_geometric/nn/nlp/sentence_transformer.py +134 -0
- torch_geometric/nn/nlp/vision_transformer.py +33 -0
- torch_geometric/nn/norm/batch_norm.py +1 -1
- torch_geometric/nn/parameter_dict.py +2 -2
- torch_geometric/nn/pool/__init__.py +21 -5
- torch_geometric/nn/pool/cluster_pool.py +145 -0
- torch_geometric/nn/pool/connect/base.py +0 -1
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/graclus.py +4 -2
- torch_geometric/nn/pool/pool.py +8 -2
- torch_geometric/nn/pool/select/base.py +0 -1
- torch_geometric/nn/pool/voxel_grid.py +3 -2
- torch_geometric/nn/resolver.py +1 -1
- torch_geometric/nn/sequential.jinja +10 -23
- torch_geometric/nn/sequential.py +204 -78
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +30 -19
- torch_geometric/resolver.py +1 -1
- torch_geometric/sampler/base.py +34 -13
- torch_geometric/sampler/neighbor_sampler.py +11 -10
- torch_geometric/sampler/utils.py +1 -1
- torch_geometric/template.py +1 -0
- torch_geometric/testing/__init__.py +6 -2
- torch_geometric/testing/decorators.py +53 -20
- torch_geometric/testing/feature_store.py +1 -1
- torch_geometric/transforms/__init__.py +2 -0
- torch_geometric/transforms/add_metapaths.py +5 -5
- torch_geometric/transforms/add_positional_encoding.py +1 -1
- torch_geometric/transforms/delaunay.py +65 -14
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +7 -6
- torch_geometric/transforms/laplacian_lambda_max.py +3 -3
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -2
- torch_geometric/transforms/pad.py +7 -6
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/to_sparse_tensor.py +1 -1
- torch_geometric/transforms/two_hop.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +43 -6
- torch_geometric/utils/__init__.py +5 -1
- torch_geometric/utils/_negative_sampling.py +1 -1
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +38 -12
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +12 -8
- torch_geometric/utils/geodesic.py +24 -22
- torch_geometric/utils/hetero.py +1 -1
- torch_geometric/utils/map.py +8 -2
- torch_geometric/utils/smiles.py +65 -27
- torch_geometric/utils/sparse.py +39 -25
- torch_geometric/visualization/graph.py +3 -4
@@ -50,7 +50,6 @@ class ExplainerAlgorithm(torch.nn.Module):
|
|
50
50
|
r"""Checks if the explainer supports the user-defined settings provided
|
51
51
|
in :obj:`self.explainer_config`, :obj:`self.model_config`.
|
52
52
|
"""
|
53
|
-
pass
|
54
53
|
|
55
54
|
###########################################################################
|
56
55
|
|
@@ -1,7 +1,6 @@
|
|
1
1
|
import math
|
2
2
|
from typing import List, Optional, Tuple, Union
|
3
3
|
|
4
|
-
import numpy as np
|
5
4
|
import torch
|
6
5
|
import torch.nn.functional as F
|
7
6
|
from torch import Tensor
|
@@ -162,7 +161,7 @@ class GraphMaskExplainer(ExplainerAlgorithm):
|
|
162
161
|
(torch.log(u) - torch.log(1 - u) + input_element) / beta)
|
163
162
|
|
164
163
|
penalty = torch.sigmoid(input_element -
|
165
|
-
beta *
|
164
|
+
beta * math.log(-gamma / zeta))
|
166
165
|
else:
|
167
166
|
s = torch.sigmoid(input_element)
|
168
167
|
penalty = torch.zeros_like(input_element)
|
@@ -340,10 +340,10 @@ class HeteroExplanation(HeteroData, ExplanationMixin):
|
|
340
340
|
"""
|
341
341
|
node_mask_dict = self.node_mask_dict
|
342
342
|
for node_mask in node_mask_dict.values():
|
343
|
-
if node_mask.dim() != 2
|
343
|
+
if node_mask.dim() != 2:
|
344
344
|
raise ValueError(f"Cannot compute feature importance for "
|
345
345
|
f"object-level 'node_mask' "
|
346
|
-
f"(got shape {
|
346
|
+
f"(got shape {node_mask.size()})")
|
347
347
|
|
348
348
|
if feat_labels is None:
|
349
349
|
feat_labels = {}
|
@@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional, Union
|
|
6
6
|
import torch
|
7
7
|
|
8
8
|
from torch_geometric.graphgym.config import cfg
|
9
|
+
from torch_geometric.io import fs
|
9
10
|
|
10
11
|
MODEL_STATE = 'model_state'
|
11
12
|
OPTIMIZER_STATE = 'optimizer_state'
|
@@ -25,7 +26,7 @@ def load_ckpt(
|
|
25
26
|
if not osp.exists(path):
|
26
27
|
return 0
|
27
28
|
|
28
|
-
ckpt =
|
29
|
+
ckpt = fs.torch_load(path)
|
29
30
|
model.load_state_dict(ckpt[MODEL_STATE])
|
30
31
|
if optimizer is not None and OPTIMIZER_STATE in ckpt:
|
31
32
|
optimizer.load_state_dict(ckpt[OPTIMIZER_STATE])
|
@@ -19,7 +19,7 @@ def set_printing():
|
|
19
19
|
logging.root.handlers = []
|
20
20
|
logging_cfg = {'level': logging.INFO, 'format': '%(message)s'}
|
21
21
|
os.makedirs(cfg.run_dir, exist_ok=True)
|
22
|
-
h_file = logging.FileHandler('{}/logging.log'
|
22
|
+
h_file = logging.FileHandler(f'{cfg.run_dir}/logging.log')
|
23
23
|
h_stdout = logging.StreamHandler(sys.stdout)
|
24
24
|
if cfg.print == 'file':
|
25
25
|
logging_cfg['handlers'] = [h_file]
|
@@ -40,7 +40,7 @@ class Logger:
|
|
40
40
|
self._epoch_total = cfg.optim.max_epoch
|
41
41
|
self._time_total = 0 # won't be reset
|
42
42
|
|
43
|
-
self.out_dir = '{}/{}'
|
43
|
+
self.out_dir = f'{cfg.run_dir}/{name}'
|
44
44
|
os.makedirs(self.out_dir, exist_ok=True)
|
45
45
|
if cfg.tensorboard_each_run:
|
46
46
|
from tensorboardX import SummaryWriter
|
@@ -210,9 +210,9 @@ class Logger:
|
|
210
210
|
}
|
211
211
|
|
212
212
|
# print
|
213
|
-
logging.info('{}: {}'
|
213
|
+
logging.info(f'{self.name}: {stats}')
|
214
214
|
# json
|
215
|
-
dict_to_json(stats, '{}/stats.json'
|
215
|
+
dict_to_json(stats, f'{self.out_dir}/stats.json')
|
216
216
|
# tensorboard
|
217
217
|
if cfg.tensorboard_each_run:
|
218
218
|
dict_to_tb(stats, self.tb_writer, cur_epoch)
|
torch_geometric/graphgym/loss.py
CHANGED
@@ -54,7 +54,7 @@ def agg_dict_list(dict_list):
|
|
54
54
|
if key != 'epoch':
|
55
55
|
value = np.array([dict[key] for dict in dict_list])
|
56
56
|
dict_agg[key] = np.mean(value).round(cfg.round)
|
57
|
-
dict_agg['{}_std'
|
57
|
+
dict_agg[f'{key}_std'] = np.std(value).round(cfg.round)
|
58
58
|
return dict_agg
|
59
59
|
|
60
60
|
|
@@ -107,7 +107,7 @@ def agg_runs(dir, metric_best='auto'):
|
|
107
107
|
[stats[metric] for stats in stats_list])
|
108
108
|
best_epoch = \
|
109
109
|
stats_list[
|
110
|
-
eval("performance_np.{}()"
|
110
|
+
eval(f"performance_np.{cfg.metric_agg}()")][
|
111
111
|
'epoch']
|
112
112
|
print(best_epoch)
|
113
113
|
|
@@ -190,7 +190,7 @@ def agg_batch(dir, metric_best='auto'):
|
|
190
190
|
results[key] = pd.DataFrame(results[key])
|
191
191
|
results[key] = results[key].sort_values(
|
192
192
|
list(dict_name.keys()), ascending=[True] * len(dict_name))
|
193
|
-
fname = osp.join(dir_out, '{}_best.csv'
|
193
|
+
fname = osp.join(dir_out, f'{key}_best.csv')
|
194
194
|
results[key].to_csv(fname, index=False)
|
195
195
|
|
196
196
|
results = {'train': [], 'val': [], 'test': []}
|
@@ -213,7 +213,7 @@ def agg_batch(dir, metric_best='auto'):
|
|
213
213
|
results[key] = pd.DataFrame(results[key])
|
214
214
|
results[key] = results[key].sort_values(
|
215
215
|
list(dict_name.keys()), ascending=[True] * len(dict_name))
|
216
|
-
fname = osp.join(dir_out, '{}.csv'
|
216
|
+
fname = osp.join(dir_out, f'{key}.csv')
|
217
217
|
results[key].to_csv(fname, index=False)
|
218
218
|
|
219
219
|
results = {'train': [], 'val': [], 'test': []}
|
@@ -245,7 +245,7 @@ def agg_batch(dir, metric_best='auto'):
|
|
245
245
|
results[key] = pd.DataFrame(results[key])
|
246
246
|
results[key] = results[key].sort_values(
|
247
247
|
list(dict_name.keys()), ascending=[True] * len(dict_name))
|
248
|
-
fname = osp.join(dir_out, '{}_bestepoch.csv'
|
248
|
+
fname = osp.join(dir_out, f'{key}_bestepoch.csv')
|
249
249
|
results[key].to_csv(fname, index=False)
|
250
250
|
|
251
|
-
print('Results aggregated across models saved in {}'
|
251
|
+
print(f'Results aggregated across models saved in {dir_out}')
|