pyg-nightly 2.6.0.dev20240319__py3-none-any.whl → 2.7.0.dev20250114__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/METADATA +31 -47
- {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/RECORD +226 -199
- {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/WHEEL +1 -1
- torch_geometric/__init__.py +28 -1
- torch_geometric/_compile.py +8 -1
- torch_geometric/_onnx.py +14 -0
- torch_geometric/config_mixin.py +113 -0
- torch_geometric/config_store.py +28 -19
- torch_geometric/data/__init__.py +24 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +8 -2
- torch_geometric/data/data.py +16 -8
- torch_geometric/data/database.py +61 -15
- torch_geometric/data/dataset.py +14 -6
- torch_geometric/data/feature_store.py +25 -42
- torch_geometric/data/graph_store.py +1 -5
- torch_geometric/data/hetero_data.py +18 -9
- torch_geometric/data/in_memory_dataset.py +2 -4
- torch_geometric/data/large_graph_indexer.py +677 -0
- torch_geometric/data/lightning/datamodule.py +4 -4
- torch_geometric/data/separate.py +6 -1
- torch_geometric/data/storage.py +17 -7
- torch_geometric/data/summary.py +14 -4
- torch_geometric/data/temporal.py +1 -2
- torch_geometric/datasets/__init__.py +17 -2
- torch_geometric/datasets/actor.py +9 -11
- torch_geometric/datasets/airfrans.py +15 -18
- torch_geometric/datasets/airports.py +10 -12
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +9 -10
- torch_geometric/datasets/amazon_products.py +9 -10
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +10 -12
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/bitcoin_otc.py +1 -1
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/cornell.py +145 -0
- torch_geometric/datasets/dblp.py +2 -1
- torch_geometric/datasets/dbp15k.py +2 -2
- torch_geometric/datasets/fake.py +1 -3
- torch_geometric/datasets/flickr.py +2 -1
- torch_geometric/datasets/freebase.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +11 -10
- torch_geometric/datasets/hgb_dataset.py +8 -8
- torch_geometric/datasets/imdb.py +2 -1
- torch_geometric/datasets/karate.py +3 -2
- torch_geometric/datasets/last_fm.py +2 -1
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +4 -3
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
- torch_geometric/datasets/molecule_net.py +15 -3
- torch_geometric/datasets/motif_generator/base.py +0 -1
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +239 -0
- torch_geometric/datasets/ose_gvcs.py +1 -1
- torch_geometric/datasets/pascal.py +11 -9
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcpnet_dataset.py +1 -1
- torch_geometric/datasets/pcqm4m.py +10 -3
- torch_geometric/datasets/ppi.py +1 -1
- torch_geometric/datasets/qm9.py +8 -7
- torch_geometric/datasets/rcdd.py +4 -4
- torch_geometric/datasets/reddit.py +2 -1
- torch_geometric/datasets/reddit2.py +2 -1
- torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
- torch_geometric/datasets/s3dis.py +5 -3
- torch_geometric/datasets/shapenet.py +3 -3
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/snap_dataset.py +7 -1
- torch_geometric/datasets/tag_dataset.py +350 -0
- torch_geometric/datasets/upfd.py +2 -1
- torch_geometric/datasets/web_qsp_dataset.py +246 -0
- torch_geometric/datasets/webkb.py +2 -2
- torch_geometric/datasets/wikics.py +1 -1
- torch_geometric/datasets/wikidata.py +3 -2
- torch_geometric/datasets/wikipedia_network.py +2 -2
- torch_geometric/datasets/willow_object_class.py +1 -1
- torch_geometric/datasets/word_net.py +2 -2
- torch_geometric/datasets/yelp.py +2 -1
- torch_geometric/datasets/zinc.py +1 -1
- torch_geometric/device.py +42 -0
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/edge_index.py +616 -438
- torch_geometric/explain/algorithm/base.py +0 -1
- torch_geometric/explain/algorithm/graphmask_explainer.py +1 -2
- torch_geometric/explain/algorithm/pg_explainer.py +1 -1
- torch_geometric/explain/explanation.py +2 -2
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/logger.py +4 -4
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/utils/agg_runs.py +6 -6
- torch_geometric/index.py +826 -0
- torch_geometric/inspector.py +8 -3
- torch_geometric/io/fs.py +28 -2
- torch_geometric/io/npz.py +2 -1
- torch_geometric/io/off.py +2 -2
- torch_geometric/io/sdf.py +2 -2
- torch_geometric/io/tu.py +4 -5
- torch_geometric/loader/__init__.py +4 -0
- torch_geometric/loader/cluster.py +10 -4
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +12 -4
- torch_geometric/loader/mixin.py +1 -1
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +1 -1
- torch_geometric/loader/rag_loader.py +107 -0
- torch_geometric/loader/utils.py +8 -7
- torch_geometric/loader/zip_loader.py +10 -0
- torch_geometric/metrics/__init__.py +11 -2
- torch_geometric/metrics/link_pred.py +159 -34
- torch_geometric/nn/aggr/__init__.py +4 -0
- torch_geometric/nn/aggr/attention.py +0 -2
- torch_geometric/nn/aggr/base.py +2 -4
- torch_geometric/nn/aggr/patch_transformer.py +143 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/variance_preserving.py +33 -0
- torch_geometric/nn/attention/__init__.py +5 -1
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/conv/collect.jinja +7 -4
- torch_geometric/nn/conv/cugraph/base.py +8 -12
- torch_geometric/nn/conv/edge_conv.py +3 -2
- torch_geometric/nn/conv/fused_gat_conv.py +1 -1
- torch_geometric/nn/conv/gat_conv.py +35 -7
- torch_geometric/nn/conv/gatv2_conv.py +36 -6
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/graph_conv.py +21 -3
- torch_geometric/nn/conv/gravnet_conv.py +3 -2
- torch_geometric/nn/conv/hetero_conv.py +3 -3
- torch_geometric/nn/conv/hgt_conv.py +1 -1
- torch_geometric/nn/conv/message_passing.py +138 -87
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/propagate.jinja +9 -1
- torch_geometric/nn/conv/rgcn_conv.py +5 -5
- torch_geometric/nn/conv/spline_conv.py +4 -4
- torch_geometric/nn/conv/x_conv.py +3 -2
- torch_geometric/nn/dense/linear.py +11 -6
- torch_geometric/nn/fx.py +3 -3
- torch_geometric/nn/model_hub.py +3 -1
- torch_geometric/nn/models/__init__.py +10 -2
- torch_geometric/nn/models/deep_graph_infomax.py +1 -2
- torch_geometric/nn/models/dimenet_utils.py +5 -7
- torch_geometric/nn/models/g_retriever.py +230 -0
- torch_geometric/nn/models/git_mol.py +336 -0
- torch_geometric/nn/models/glem.py +385 -0
- torch_geometric/nn/models/gnnff.py +0 -1
- torch_geometric/nn/models/graph_unet.py +12 -3
- torch_geometric/nn/models/jumping_knowledge.py +63 -4
- torch_geometric/nn/models/lightgcn.py +1 -1
- torch_geometric/nn/models/metapath2vec.py +5 -5
- torch_geometric/nn/models/molecule_gpt.py +222 -0
- torch_geometric/nn/models/node2vec.py +2 -3
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/signed_gcn.py +3 -3
- torch_geometric/nn/module_dict.py +2 -2
- torch_geometric/nn/nlp/__init__.py +9 -0
- torch_geometric/nn/nlp/llm.py +322 -0
- torch_geometric/nn/nlp/sentence_transformer.py +134 -0
- torch_geometric/nn/nlp/vision_transformer.py +33 -0
- torch_geometric/nn/norm/batch_norm.py +1 -1
- torch_geometric/nn/parameter_dict.py +2 -2
- torch_geometric/nn/pool/__init__.py +21 -5
- torch_geometric/nn/pool/cluster_pool.py +145 -0
- torch_geometric/nn/pool/connect/base.py +0 -1
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/graclus.py +4 -2
- torch_geometric/nn/pool/pool.py +8 -2
- torch_geometric/nn/pool/select/base.py +0 -1
- torch_geometric/nn/pool/voxel_grid.py +3 -2
- torch_geometric/nn/resolver.py +1 -1
- torch_geometric/nn/sequential.jinja +10 -23
- torch_geometric/nn/sequential.py +204 -78
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +30 -19
- torch_geometric/resolver.py +1 -1
- torch_geometric/sampler/base.py +34 -13
- torch_geometric/sampler/neighbor_sampler.py +11 -10
- torch_geometric/sampler/utils.py +1 -1
- torch_geometric/template.py +1 -0
- torch_geometric/testing/__init__.py +6 -2
- torch_geometric/testing/decorators.py +53 -20
- torch_geometric/testing/feature_store.py +1 -1
- torch_geometric/transforms/__init__.py +2 -0
- torch_geometric/transforms/add_metapaths.py +5 -5
- torch_geometric/transforms/add_positional_encoding.py +1 -1
- torch_geometric/transforms/delaunay.py +65 -14
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +7 -6
- torch_geometric/transforms/laplacian_lambda_max.py +3 -3
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -2
- torch_geometric/transforms/pad.py +7 -6
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/to_sparse_tensor.py +1 -1
- torch_geometric/transforms/two_hop.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +43 -6
- torch_geometric/utils/__init__.py +5 -1
- torch_geometric/utils/_negative_sampling.py +1 -1
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +38 -12
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +12 -8
- torch_geometric/utils/geodesic.py +24 -22
- torch_geometric/utils/hetero.py +1 -1
- torch_geometric/utils/map.py +8 -2
- torch_geometric/utils/smiles.py +65 -27
- torch_geometric/utils/sparse.py +39 -25
- torch_geometric/visualization/graph.py +3 -4
@@ -0,0 +1,677 @@
|
|
1
|
+
import os
|
2
|
+
import pickle as pkl
|
3
|
+
import shutil
|
4
|
+
from dataclasses import dataclass
|
5
|
+
from itertools import chain
|
6
|
+
from typing import (
|
7
|
+
Any,
|
8
|
+
Callable,
|
9
|
+
Dict,
|
10
|
+
Iterable,
|
11
|
+
Iterator,
|
12
|
+
List,
|
13
|
+
Optional,
|
14
|
+
Sequence,
|
15
|
+
Set,
|
16
|
+
Tuple,
|
17
|
+
Union,
|
18
|
+
)
|
19
|
+
|
20
|
+
import torch
|
21
|
+
from torch import Tensor
|
22
|
+
from tqdm import tqdm
|
23
|
+
|
24
|
+
from torch_geometric.data import Data
|
25
|
+
from torch_geometric.typing import WITH_PT24
|
26
|
+
|
27
|
+
# An edge is a (head, relation, tail) triplet. It is annotated with strings,
# but the indexer only needs the elements to be hashable: heads/tails and
# whole triplets are used as dict keys and set members.
TripletLike = Tuple[str, str, str]

# A collection of triplets describing a (sub)graph.
KnowledgeGraphLike = Iterable[TripletLike]
|
31
|
+
|
32
|
+
|
33
|
+
def ordered_set(values: Iterable[str]) -> List[str]:
    """Return the distinct elements of ``values`` in first-seen order.

    Works like ``set(values)`` but keeps the order in which each value first
    appears, so the result is deterministic for a given input order.
    """
    first_seen: Dict[str, None] = {}
    for item in values:
        # setdefault keeps the first occurrence and ignores later duplicates.
        first_seen.setdefault(item, None)
    return [*first_seen]
|
35
|
+
|
36
|
+
|
37
|
+
# TODO: Refactor Node and Edge funcs and attrs to be accessible via an Enum?

# Key under which the unique node ids are stored in ``node_attr``.
NODE_PID = "pid"
# Keys that every ``node_attr`` mapping is required to contain.
NODE_KEYS = {NODE_PID}

# Keys of the per-edge lists in ``edge_attr``: the full triplet, its head,
# relation and tail parts, and the (head position, tail position) node-index
# pair.
EDGE_PID, EDGE_HEAD, EDGE_RELATION, EDGE_TAIL, EDGE_INDEX = (
    "e_pid",
    "h",
    "r",
    "t",
    "edge_idx",
)
# Keys that every ``edge_attr`` mapping is required to contain.
EDGE_KEYS = {EDGE_PID, EDGE_HEAD, EDGE_RELATION, EDGE_TAIL, EDGE_INDEX}
|
50
|
+
|
51
|
+
# Feature values are stored either as a plain sequence (e.g. a list) or as a
# tensor.
FeatureValueType = Union[Sequence[Any], Tensor]


@dataclass
class MappedFeature:
    """A feature whose values are looked up through another feature.

    Instead of one value per node/edge id, ``values`` holds one entry per
    *unique* value of the feature called ``name``. A node's/edge's value is
    found by taking its ``name`` feature and locating that value's position
    among the unique values.

    Attributes:
        name: Name of the feature whose unique values index ``values``.
        values: One entry per unique value of feature ``name``.
    """
    name: str
    values: FeatureValueType

    def __eq__(self, value: object) -> bool:
        # Let Python fall back to its default handling for foreign types
        # instead of failing on a missing ``.name`` attribute.
        if not isinstance(value, MappedFeature):
            return NotImplemented
        if self.name != value.name:
            return False
        self_is_tensor = isinstance(self.values, torch.Tensor)
        other_is_tensor = isinstance(value.values, torch.Tensor)
        if self_is_tensor or other_is_tensor:
            # ``torch.equal`` requires two tensors; a tensor never equals a
            # plain sequence here (avoids TypeError / element-wise results).
            return (self_is_tensor and other_is_tensor
                    and torch.equal(self.values, value.values))
        return self.values == value.values
|
66
|
+
|
67
|
+
|
68
|
+
# ``save`` writes node/edge attribute values (which may be ``MappedFeature``
# instances) with ``torch.save``, and ``from_disk`` reads them back with
# ``torch.load``. Registering the class lets PyTorch's restricted
# (``weights_only``) unpickler accept it. The guard presumably reflects that
# ``add_safe_globals`` is only available from PyTorch 2.4 onwards — confirm
# against ``WITH_PT24``'s definition.
if WITH_PT24:
    torch.serialization.add_safe_globals([MappedFeature])
|
70
|
+
|
71
|
+
|
72
|
+
class LargeGraphIndexer:
    """For a dataset that consists of multiple subgraphs that are assumed to
    be part of a much larger graph, collate the values into a large graph store
    to save resources.

    Nodes are catalogued by their unique id (stored under ``NODE_PID`` in
    :attr:`node_attr`). Edges are catalogued by their unique
    ``(head, relation, tail)`` triplet (stored under ``EDGE_PID`` in
    :attr:`edge_attr`). Extra per-node/per-edge features can be attached with
    :meth:`add_node_feature` and :meth:`add_edge_feature`.
    """
    def __init__(
        self,
        nodes: Iterable[str],
        edges: KnowledgeGraphLike,
        node_attr: Optional[Dict[str, List[Any]]] = None,
        edge_attr: Optional[Dict[str, List[Any]]] = None,
    ) -> None:
        r"""Constructs a new index that uniquely catalogs each node and edge
        by id. Not meant to be used directly.

        Args:
            nodes (Iterable[str]): Node ids in the graph. Must be unique and
                support ``len()``.
            edges (KnowledgeGraphLike): Edge triplets in the graph. Must be
                unique and support ``len()``.
            node_attr (Optional[Dict[str, List[Any]]], optional): Mapping node
                attribute name and list of their values in order of unique node
                ids. Must contain ``NODE_PID`` equal to :obj:`nodes`.
                Defaults to None.
            edge_attr (Optional[Dict[str, List[Any]]], optional): Mapping edge
                attribute name and list of their values in order of unique edge
                ids. Must contain every key in ``EDGE_KEYS``, with ``EDGE_PID``
                equal to :obj:`edges`. Defaults to None.

        Raises:
            AttributeError: If nodes or edges are not unique, or if a given
                attribute mapping is missing required keys or disagrees with
                :obj:`nodes` / :obj:`edges`.
            KeyError: If :obj:`edge_attr` is not given and an edge's head or
                tail is not among :obj:`nodes`.
        """
        self._nodes: Dict[str, int] = dict()
        self._edges: Dict[TripletLike, int] = dict()

        self._mapped_node_features: Set[str] = set()
        self._mapped_edge_features: Set[str] = set()

        if len(nodes) != len(set(nodes)):
            raise AttributeError("Nodes need to be unique")
        if len(edges) != len(set(edges)):
            raise AttributeError("Edges need to be unique")

        if node_attr is not None:
            # TODO: Validity checks btw nodes and node_attr
            self.node_attr: Dict[str, Any] = node_attr
            missing_node_keys = NODE_KEYS - set(self.node_attr.keys())
            if missing_node_keys:
                raise AttributeError(
                    "Invalid node_attr object. Missing " +
                    f"{missing_node_keys}")
            elif self.node_attr[NODE_PID] != nodes:
                raise AttributeError(
                    "Nodes provided do not match those in node_attr")
        else:
            self.node_attr = dict()
            self.node_attr[NODE_PID] = nodes

        # Node id -> position in the node list.
        for i, node in enumerate(self.node_attr[NODE_PID]):
            self._nodes[node] = i

        if edge_attr is not None:
            # TODO: Validity checks btw edges and edge_attr
            self.edge_attr: Dict[str, Any] = edge_attr

            missing_edge_keys = EDGE_KEYS - set(self.edge_attr.keys())
            if missing_edge_keys:
                raise AttributeError(
                    "Invalid edge_attr object. Missing " +
                    f"{missing_edge_keys}")
            # Compare against the *edge* attributes (previously this read
            # ``node_attr[EDGE_PID]`` and always raised KeyError).
            elif self.edge_attr[EDGE_PID] != edges:
                raise AttributeError(
                    "Edges provided do not match those in edge_attr")

        else:
            self.edge_attr = {default_key: list() for default_key in EDGE_KEYS}
            self.edge_attr[EDGE_PID] = edges

            for h, r, t in edges:
                self.edge_attr[EDGE_HEAD].append(h)
                self.edge_attr[EDGE_RELATION].append(r)
                self.edge_attr[EDGE_TAIL].append(t)
                # (head position, tail position) in the node list.
                self.edge_attr[EDGE_INDEX].append(
                    (self._nodes[h], self._nodes[t]))

        # Edge triplet -> position in the edge list.
        for i, tup in enumerate(edges):
            self._edges[tup] = i

    @classmethod
    def from_triplets(
        cls,
        triplets: KnowledgeGraphLike,
        pre_transform: Optional[Callable[[TripletLike], TripletLike]] = None,
    ) -> "LargeGraphIndexer":
        r"""Generate a new index from a series of triplets that represent edge
        relations between nodes.
        Formatted like (source_node, edge, dest_node).

        Args:
            triplets (KnowledgeGraphLike): Series of triplets representing
                knowledge graph relations.
            pre_transform (Optional[Callable[[TripletLike], TripletLike]]):
                Optional preprocessing function to apply to triplets.
                Defaults to None.

        Returns:
            LargeGraphIndexer: Index of unique nodes and edges.
        """
        # NOTE: Right now assumes that all trips can be loaded into memory
        nodes = set()
        edges = set()

        if pre_transform is not None:
            # Lazily transform each triplet as it is consumed below.
            triplets = (pre_transform(trip) for trip in triplets)

        for h, r, t in triplets:
            nodes.add(h)
            nodes.add(t)
            edges.add((h, r, t))

        return cls(list(nodes), list(edges))

    @classmethod
    def collate(cls,
                graphs: Iterable["LargeGraphIndexer"]) -> "LargeGraphIndexer":
        r"""Combines a series of large graph indexes into a single large graph
        index.

        Only the edge triplets of each input are merged; node and edge
        attributes beyond the defaults are not carried over.

        Args:
            graphs (Iterable[LargeGraphIndexer]): Indices to be
                combined.

        Returns:
            LargeGraphIndexer: Singular unique index for all nodes and edges
                in input indices.
        """
        # FIXME Needs to merge node attrs and edge attrs?
        trips = chain.from_iterable(graph.to_triplets() for graph in graphs)
        return cls.from_triplets(trips)

    def get_unique_node_features(self,
                                 feature_name: str = NODE_PID) -> List[str]:
        r"""Get all the unique values for a specific node attribute.

        Args:
            feature_name (str, optional): Name of feature to get.
                Defaults to NODE_PID.

        Returns:
            List[str]: List of unique values for the specified feature, in
                first-seen order.

        Raises:
            IndexError: If the feature is a mapped feature.
            AttributeError: If nodes have no feature with this name.
        """
        if feature_name in self._mapped_node_features:
            raise IndexError(
                "Only non-mapped features can be retrieved uniquely.")
        try:
            return ordered_set(self.get_node_features(feature_name))
        except KeyError as err:
            raise AttributeError(
                f"Nodes do not have a feature called {feature_name}") from err

    def add_node_feature(
        self,
        new_feature_name: str,
        new_feature_vals: FeatureValueType,
        map_from_feature: str = NODE_PID,
    ) -> None:
        r"""Adds a new feature that corresponds to each unique node in
        the graph.

        Args:
            new_feature_name (str): Name to call the new feature.
            new_feature_vals (FeatureValueType): Values to map for that
                new feature.
            map_from_feature (str, optional): Key of feature to map from.
                Size must match the number of feature values.
                Defaults to NODE_PID.

        Raises:
            AttributeError: If the feature already exists, if
                :obj:`map_from_feature` is itself a mapped feature, or if the
                number of values does not match the number of unique values
                of :obj:`map_from_feature`.
        """
        if new_feature_name in self.node_attr:
            raise AttributeError("Features cannot be overridden once created")
        if map_from_feature in self._mapped_node_features:
            raise AttributeError(
                f"{map_from_feature} is already a feature mapping.")

        feature_keys = self.get_unique_node_features(map_from_feature)
        if len(feature_keys) != len(new_feature_vals):
            # Both parts must be f-strings so the counts are interpolated.
            raise AttributeError(
                f"Expected encodings for {len(feature_keys)} unique features,"
                + f" but got {len(new_feature_vals)} encodings.")

        if map_from_feature == NODE_PID:
            self.node_attr[new_feature_name] = new_feature_vals
        else:
            # Values are keyed by the unique values of ``map_from_feature``.
            self.node_attr[new_feature_name] = MappedFeature(
                name=map_from_feature, values=new_feature_vals)
            self._mapped_node_features.add(new_feature_name)

    def get_node_features(
        self,
        feature_name: str = NODE_PID,
        pids: Optional[Iterable[str]] = None,
    ) -> List[Any]:
        r"""Get node feature values for a given set of unique node ids.
        Returned values are not necessarily unique.

        Args:
            feature_name (str, optional): Name of feature to fetch. Defaults
                to NODE_PID.
            pids (Optional[Iterable[str]], optional): Node ids to fetch
                for. Defaults to None, which fetches all nodes.

        Returns:
            List[Any]: Node features corresponding to the specified ids. If
                the stored values are a tensor, a tensor is returned instead.
        """
        if feature_name in self._mapped_node_features:
            values = self.node_attr[feature_name].values
        else:
            values = self.node_attr[feature_name]

        # TODO: torch_geometric.utils.select
        if isinstance(values, torch.Tensor):
            idxs = list(
                self.get_node_features_iter(feature_name, pids,
                                            index_only=True))
            return values[idxs]
        return list(self.get_node_features_iter(feature_name, pids))

    def get_node_features_iter(
        self,
        feature_name: str = NODE_PID,
        pids: Optional[Iterable[str]] = None,
        index_only: bool = False,
    ) -> Iterator[Any]:
        """Iterator version of get_node_features. If index_only is True,
        yields indices instead of values.
        """
        if pids is None:
            pids = self.node_attr[NODE_PID]

        if feature_name in self._mapped_node_features:
            feature_map_info = self.node_attr[feature_name]
            from_feature_name = feature_map_info.name
            to_feature_vals = feature_map_info.values
            # Position of each unique source value inside ``to_feature_vals``.
            from_feature_vals = self.get_unique_node_features(
                from_feature_name)
            feature_mapping = {k: i for i, k in enumerate(from_feature_vals)}

            for pid in pids:
                idx = self._nodes[pid]
                from_feature_val = self.node_attr[from_feature_name][idx]
                to_feature_idx = feature_mapping[from_feature_val]
                if index_only:
                    yield to_feature_idx
                else:
                    yield to_feature_vals[to_feature_idx]
        else:
            for pid in pids:
                idx = self._nodes[pid]
                if index_only:
                    yield idx
                else:
                    yield self.node_attr[feature_name][idx]

    def get_unique_edge_features(self,
                                 feature_name: str = EDGE_PID) -> List[str]:
        r"""Get all the unique values for a specific edge attribute.

        Args:
            feature_name (str, optional): Name of feature to get.
                Defaults to EDGE_PID.

        Returns:
            List[str]: List of unique values for the specified feature, in
                first-seen order.

        Raises:
            IndexError: If the feature is a mapped feature.
            AttributeError: If edges have no feature with this name.
        """
        if feature_name in self._mapped_edge_features:
            raise IndexError(
                "Only non-mapped features can be retrieved uniquely.")
        try:
            return ordered_set(self.get_edge_features(feature_name))
        except KeyError as err:
            raise AttributeError(
                f"Edges do not have a feature called {feature_name}") from err

    def add_edge_feature(
        self,
        new_feature_name: str,
        new_feature_vals: FeatureValueType,
        map_from_feature: str = EDGE_PID,
    ) -> None:
        r"""Adds a new feature that corresponds to each unique edge in
        the graph.

        Args:
            new_feature_name (str): Name to call the new feature.
            new_feature_vals (FeatureValueType): Values to map for that new
                feature.
            map_from_feature (str, optional): Key of feature to map from.
                Size must match the number of feature values.
                Defaults to EDGE_PID.

        Raises:
            AttributeError: If the feature already exists, if
                :obj:`map_from_feature` is itself a mapped feature, or if the
                number of values does not match the number of unique values
                of :obj:`map_from_feature`.
        """
        if new_feature_name in self.edge_attr:
            raise AttributeError("Features cannot be overridden once created")
        if map_from_feature in self._mapped_edge_features:
            raise AttributeError(
                f"{map_from_feature} is already a feature mapping.")

        feature_keys = self.get_unique_edge_features(map_from_feature)
        if len(feature_keys) != len(new_feature_vals):
            raise AttributeError(
                f"Expected encodings for {len(feature_keys)} unique features, "
                + f"but got {len(new_feature_vals)} encodings.")

        if map_from_feature == EDGE_PID:
            self.edge_attr[new_feature_name] = new_feature_vals
        else:
            # Values are keyed by the unique values of ``map_from_feature``.
            self.edge_attr[new_feature_name] = MappedFeature(
                name=map_from_feature, values=new_feature_vals)
            self._mapped_edge_features.add(new_feature_name)

    def get_edge_features(
        self,
        feature_name: str = EDGE_PID,
        pids: Optional[Iterable[str]] = None,
    ) -> List[Any]:
        r"""Get edge feature values for a given set of unique edge ids.
        Returned values are not necessarily unique.

        Args:
            feature_name (str, optional): Name of feature to fetch.
                Defaults to EDGE_PID.
            pids (Optional[Iterable[str]], optional): Edge ids to fetch
                for. Defaults to None, which fetches all edges.

        Returns:
            List[Any]: Edge features corresponding to the specified ids. If
                the stored values are a tensor, a tensor is returned instead.
        """
        if feature_name in self._mapped_edge_features:
            values = self.edge_attr[feature_name].values
        else:
            values = self.edge_attr[feature_name]

        # TODO: torch_geometric.utils.select
        if isinstance(values, torch.Tensor):
            idxs = list(
                self.get_edge_features_iter(feature_name, pids,
                                            index_only=True))
            return values[idxs]
        return list(self.get_edge_features_iter(feature_name, pids))

    def get_edge_features_iter(
        self,
        feature_name: str = EDGE_PID,
        pids: Optional[KnowledgeGraphLike] = None,
        index_only: bool = False,
    ) -> Iterator[Any]:
        """Iterator version of get_edge_features. If index_only is True,
        yields indices instead of values.
        """
        if pids is None:
            pids = self.edge_attr[EDGE_PID]

        if feature_name in self._mapped_edge_features:
            feature_map_info = self.edge_attr[feature_name]
            from_feature_name = feature_map_info.name
            to_feature_vals = feature_map_info.values
            # Position of each unique source value inside ``to_feature_vals``.
            from_feature_vals = self.get_unique_edge_features(
                from_feature_name)
            feature_mapping = {k: i for i, k in enumerate(from_feature_vals)}

            for pid in pids:
                idx = self._edges[pid]
                from_feature_val = self.edge_attr[from_feature_name][idx]
                to_feature_idx = feature_mapping[from_feature_val]
                if index_only:
                    yield to_feature_idx
                else:
                    yield to_feature_vals[to_feature_idx]
        else:
            for pid in pids:
                idx = self._edges[pid]
                if index_only:
                    yield idx
                else:
                    yield self.edge_attr[feature_name][idx]

    def to_triplets(self) -> Iterator[TripletLike]:
        """Iterate over the stored edge triplets in index order."""
        return iter(self.edge_attr[EDGE_PID])

    def save(self, path: str) -> None:
        """Persist the index to the directory ``path``.

        Any existing directory at ``path`` is deleted first. Index maps are
        pickled, and each node/edge attribute is written with ``torch.save``
        as ``<attr_name>.pt``.
        """
        if os.path.exists(path):
            shutil.rmtree(path)
        os.makedirs(path, exist_ok=True)
        with open(os.path.join(path, "edges"), "wb") as f:
            pkl.dump(self._edges, f)
        with open(os.path.join(path, "nodes"), "wb") as f:
            pkl.dump(self._nodes, f)

        with open(os.path.join(path, "mapped_edges"), "wb") as f:
            pkl.dump(self._mapped_edge_features, f)
        with open(os.path.join(path, "mapped_nodes"), "wb") as f:
            pkl.dump(self._mapped_node_features, f)

        node_attr_path = os.path.join(path, "node_attr")
        os.makedirs(node_attr_path, exist_ok=True)
        for attr_name, vals in self.node_attr.items():
            torch.save(vals, os.path.join(node_attr_path, f"{attr_name}.pt"))

        edge_attr_path = os.path.join(path, "edge_attr")
        os.makedirs(edge_attr_path, exist_ok=True)
        for attr_name, vals in self.edge_attr.items():
            torch.save(vals, os.path.join(edge_attr_path, f"{attr_name}.pt"))

    @classmethod
    def from_disk(cls, path: str) -> "LargeGraphIndexer":
        """Load an index previously written by :meth:`save`.

        NOTE(security): the index maps are read with ``pickle.load``. Only
        load directories from trusted sources.
        """
        indexer = cls(list(), list())
        with open(os.path.join(path, "edges"), "rb") as f:
            indexer._edges = pkl.load(f)
        with open(os.path.join(path, "nodes"), "rb") as f:
            indexer._nodes = pkl.load(f)

        with open(os.path.join(path, "mapped_edges"), "rb") as f:
            indexer._mapped_edge_features = pkl.load(f)
        with open(os.path.join(path, "mapped_nodes"), "rb") as f:
            indexer._mapped_node_features = pkl.load(f)

        node_attr_path = os.path.join(path, "node_attr")
        for fname in os.listdir(node_attr_path):
            # Strip only the extension so attribute names containing dots
            # round-trip intact.
            key = os.path.splitext(fname)[0]
            indexer.node_attr[key] = torch.load(
                os.path.join(node_attr_path, fname))

        edge_attr_path = os.path.join(path, "edge_attr")
        for fname in os.listdir(edge_attr_path):
            key = os.path.splitext(fname)[0]
            indexer.edge_attr[key] = torch.load(
                os.path.join(edge_attr_path, fname))

        return indexer

    def to_data(self, node_feature_name: str,
                edge_feature_name: Optional[str] = None) -> Data:
        """Return a Data object containing all the specified node and
        edge features and the graph.

        Args:
            node_feature_name (str): Feature to use for nodes
            edge_feature_name (Optional[str], optional): Feature to use for
                edges. Defaults to None, in which case ``edge_attr`` is None.

        Returns:
            Data: Data object containing the specified node and
                edge features, the graph, and ``node_id`` / ``edge_id``
                ranges over all nodes and edges.
        """
        x = torch.Tensor(self.get_node_features(node_feature_name))
        node_id = torch.LongTensor(range(len(x)))

        edge_index_pairs = self.get_edge_features(EDGE_INDEX)
        edge_index = torch.t(torch.LongTensor(edge_index_pairs))

        edge_attr = (self.get_edge_features(edge_feature_name)
                     if edge_feature_name is not None else None)
        # Count edges from the edge index rather than ``edge_attr``, which
        # is None when no edge feature is requested.
        edge_id = torch.LongTensor(range(len(edge_index_pairs)))

        return Data(x=x, edge_index=edge_index, edge_attr=edge_attr,
                    edge_id=edge_id, node_id=node_id)

    def __eq__(self, value: object) -> bool:
        """Two indexers are equal if their index maps, mapped-feature sets
        and all node/edge attribute values are equal.
        """
        if not isinstance(value, LargeGraphIndexer):
            return NotImplemented
        # Structural checks first. Matching key sets also guarantee the
        # per-key lookups below cannot raise KeyError.
        if (self._nodes != value._nodes or self._edges != value._edges
                or self.node_attr.keys() != value.node_attr.keys()
                or self.edge_attr.keys() != value.edge_attr.keys()
                or self._mapped_node_features != value._mapped_node_features
                or self._mapped_edge_features != value._mapped_edge_features):
            return False

        nodes_equal = all(
            self._attr_values_equal(self.node_attr[k], value.node_attr[k])
            for k in self.node_attr)
        return nodes_equal and all(
            self._attr_values_equal(self.edge_attr[k], value.edge_attr[k])
            for k in self.edge_attr)

    @staticmethod
    def _attr_values_equal(a: Any, b: Any) -> bool:
        """Compare two stored attribute values, using ``torch.equal`` for
        tensors and ``==`` otherwise; mismatched types are unequal.
        """
        if not isinstance(a, type(b)):
            return False
        if isinstance(a, torch.Tensor):
            return isinstance(b, torch.Tensor) and torch.equal(a, b)
        return bool(a == b)
|
567
|
+
|
568
|
+
|
569
|
+
def get_features_for_triplets_groups(
    indexer: LargeGraphIndexer,
    triplet_groups: Iterable[KnowledgeGraphLike],
    node_feature_name: str = "x",
    edge_feature_name: str = "edge_attr",
    pre_transform: Optional[Callable[[TripletLike], TripletLike]] = None,
    verbose: bool = False,
) -> Iterator[Data]:
    """Given an indexer and a series of triplet groups (like a dataset),
    retrieve the specified node and edge features for each triplet from the
    index.

    Args:
        indexer (LargeGraphIndexer): Indexer containing desired features
        triplet_groups (Iterable[KnowledgeGraphLike]): List of lists of
            triplets to fetch features for
        node_feature_name (str, optional): Node feature to fetch.
            Defaults to "x".
        edge_feature_name (str, optional): Edge feature to fetch.
            Defaults to "edge_attr".
        pre_transform (Optional[Callable[[TripletLike], TripletLike]]):
            Optional preprocessing to perform on triplets.
            Defaults to None.
        verbose (bool, optional): Whether to print progress. Defaults to False.

    Yields:
        Iterator[Data]: For each triplet group, yield a data object containing
            the unique graph and features from the index, plus that group's
            own node/edge keys (``NODE_PID``/``EDGE_PID``) and their positions
            in ``indexer`` (``node_idx``/``edge_idx``).
    """
    if pre_transform is not None:

        def apply_transform(trips):
            for trip in trips:
                yield pre_transform(tuple(trip))

        # TODO: Make this safe for large amounts of triplets?
        triplet_groups = (list(apply_transform(triplets))
                          for triplets in triplet_groups)

    # Per-group key lists. Features for all groups are fetched from
    # `indexer` in a single batched call further down.
    node_keys = []
    edge_keys = []
    edge_index = []

    for triplets in tqdm(triplet_groups, disable=not verbose):
        # `triplets` were already transformed above, so pre_transform is
        # NOT forwarded here. Otherwise (assuming from_triplets applies it
        # per triplet) a non-idempotent transform would run twice, and the
        # indexer's keys would no longer match the `pids=triplets` lookups.
        small_graph_indexer = LargeGraphIndexer.from_triplets(
            triplets, pre_transform=None)

        node_keys.append(small_graph_indexer.get_node_features())
        edge_keys.append(small_graph_indexer.get_edge_features(pids=triplets))
        edge_index.append(
            small_graph_indexer.get_edge_features(EDGE_INDEX, triplets))

    node_feats = indexer.get_node_features(feature_name=node_feature_name,
                                           pids=chain.from_iterable(node_keys))
    edge_feats = indexer.get_edge_features(feature_name=edge_feature_name,
                                           pids=chain.from_iterable(edge_keys))

    # Walk the concatenated feature rows, slicing out each group's share.
    node_offset, edge_offset = 0, 0
    for nkeys, ekeys, eidx in zip(node_keys, edge_keys, edge_index):
        nlen, elen = len(nkeys), len(ekeys)
        x = torch.Tensor(node_feats[node_offset:node_offset + nlen])
        edge_attr = torch.Tensor(edge_feats[edge_offset:edge_offset + elen])
        node_offset += nlen
        edge_offset += elen

        # Assumes eidx is a sequence of (src, dst) pairs, i.e. [E, 2].
        # reshape(-1, 2) keeps an empty group a well-formed [2, 0]
        # edge_index instead of a 1-D tensor.
        edge_idx = torch.LongTensor(eidx).reshape(-1, 2).T

        data_obj = Data(x=x, edge_attr=edge_attr, edge_index=edge_idx)
        # Attach only this group's keys, not the keys of every group.
        data_obj[NODE_PID] = nkeys
        data_obj[EDGE_PID] = ekeys
        data_obj["node_idx"] = [indexer._nodes[k] for k in nkeys]
        data_obj["edge_idx"] = [indexer._edges[e] for e in ekeys]

        yield data_obj
|
645
|
+
|
646
|
+
|
647
|
+
def get_features_for_triplets(
    indexer: LargeGraphIndexer,
    triplets: KnowledgeGraphLike,
    node_feature_name: str = "x",
    edge_feature_name: str = "edge_attr",
    pre_transform: Optional[Callable[[TripletLike], TripletLike]] = None,
    verbose: bool = False,
) -> Data:
    """Build a Data object for a single set of triplets using features from
    the index.

    This wraps ``get_features_for_triplets_groups`` with a one-element list
    of groups and returns the first (and only) object it yields.

    Args:
        indexer (LargeGraphIndexer): Indexer holding the features to look up.
        triplets (KnowledgeGraphLike): Triplets whose features are retrieved.
        node_feature_name (str, optional): Name of the node feature to use.
            Defaults to "x".
        edge_feature_name (str, optional): Name of the edge feature to use.
            Defaults to "edge_attr".
        pre_transform (Optional[Callable[[TripletLike], TripletLike]]):
            Optional per-triplet preprocessing. Defaults to None.
        verbose (bool, optional): Show a progress bar. Defaults to False.

    Returns:
        Data: The unique graph for ``triplets`` together with its features
        from the index.
    """
    single_group = [triplets]
    results = get_features_for_triplets_groups(
        indexer,
        single_group,
        node_feature_name,
        edge_feature_name,
        pre_transform,
        verbose,
    )
    # Only the first item is needed, so the generator is not exhausted.
    first_data = next(results)
    return first_data
|