pyg-nightly 2.6.0.dev20240319__py3-none-any.whl → 2.7.0.dev20250114__py3-none-any.whl

Files changed (226)
  1. {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/METADATA +31 -47
  2. {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/RECORD +226 -199
  3. {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/WHEEL +1 -1
  4. torch_geometric/__init__.py +28 -1
  5. torch_geometric/_compile.py +8 -1
  6. torch_geometric/_onnx.py +14 -0
  7. torch_geometric/config_mixin.py +113 -0
  8. torch_geometric/config_store.py +28 -19
  9. torch_geometric/data/__init__.py +24 -1
  10. torch_geometric/data/batch.py +2 -2
  11. torch_geometric/data/collate.py +8 -2
  12. torch_geometric/data/data.py +16 -8
  13. torch_geometric/data/database.py +61 -15
  14. torch_geometric/data/dataset.py +14 -6
  15. torch_geometric/data/feature_store.py +25 -42
  16. torch_geometric/data/graph_store.py +1 -5
  17. torch_geometric/data/hetero_data.py +18 -9
  18. torch_geometric/data/in_memory_dataset.py +2 -4
  19. torch_geometric/data/large_graph_indexer.py +677 -0
  20. torch_geometric/data/lightning/datamodule.py +4 -4
  21. torch_geometric/data/separate.py +6 -1
  22. torch_geometric/data/storage.py +17 -7
  23. torch_geometric/data/summary.py +14 -4
  24. torch_geometric/data/temporal.py +1 -2
  25. torch_geometric/datasets/__init__.py +17 -2
  26. torch_geometric/datasets/actor.py +9 -11
  27. torch_geometric/datasets/airfrans.py +15 -18
  28. torch_geometric/datasets/airports.py +10 -12
  29. torch_geometric/datasets/amazon.py +8 -11
  30. torch_geometric/datasets/amazon_book.py +9 -10
  31. torch_geometric/datasets/amazon_products.py +9 -10
  32. torch_geometric/datasets/aminer.py +8 -9
  33. torch_geometric/datasets/aqsol.py +10 -13
  34. torch_geometric/datasets/attributed_graph_dataset.py +10 -12
  35. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  36. torch_geometric/datasets/ba_shapes.py +5 -6
  37. torch_geometric/datasets/bitcoin_otc.py +1 -1
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/cornell.py +145 -0
  40. torch_geometric/datasets/dblp.py +2 -1
  41. torch_geometric/datasets/dbp15k.py +2 -2
  42. torch_geometric/datasets/fake.py +1 -3
  43. torch_geometric/datasets/flickr.py +2 -1
  44. torch_geometric/datasets/freebase.py +1 -1
  45. torch_geometric/datasets/gdelt_lite.py +3 -2
  46. torch_geometric/datasets/ged_dataset.py +3 -2
  47. torch_geometric/datasets/git_mol_dataset.py +263 -0
  48. torch_geometric/datasets/gnn_benchmark_dataset.py +11 -10
  49. torch_geometric/datasets/hgb_dataset.py +8 -8
  50. torch_geometric/datasets/imdb.py +2 -1
  51. torch_geometric/datasets/karate.py +3 -2
  52. torch_geometric/datasets/last_fm.py +2 -1
  53. torch_geometric/datasets/linkx_dataset.py +4 -3
  54. torch_geometric/datasets/lrgb.py +3 -5
  55. torch_geometric/datasets/malnet_tiny.py +4 -3
  56. torch_geometric/datasets/mnist_superpixels.py +2 -3
  57. torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
  58. torch_geometric/datasets/molecule_net.py +15 -3
  59. torch_geometric/datasets/motif_generator/base.py +0 -1
  60. torch_geometric/datasets/neurograph.py +1 -3
  61. torch_geometric/datasets/ogb_mag.py +1 -1
  62. torch_geometric/datasets/opf.py +239 -0
  63. torch_geometric/datasets/ose_gvcs.py +1 -1
  64. torch_geometric/datasets/pascal.py +11 -9
  65. torch_geometric/datasets/pascal_pf.py +1 -1
  66. torch_geometric/datasets/pcpnet_dataset.py +1 -1
  67. torch_geometric/datasets/pcqm4m.py +10 -3
  68. torch_geometric/datasets/ppi.py +1 -1
  69. torch_geometric/datasets/qm9.py +8 -7
  70. torch_geometric/datasets/rcdd.py +4 -4
  71. torch_geometric/datasets/reddit.py +2 -1
  72. torch_geometric/datasets/reddit2.py +2 -1
  73. torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
  74. torch_geometric/datasets/s3dis.py +5 -3
  75. torch_geometric/datasets/shapenet.py +3 -3
  76. torch_geometric/datasets/shrec2016.py +2 -2
  77. torch_geometric/datasets/snap_dataset.py +7 -1
  78. torch_geometric/datasets/tag_dataset.py +350 -0
  79. torch_geometric/datasets/upfd.py +2 -1
  80. torch_geometric/datasets/web_qsp_dataset.py +246 -0
  81. torch_geometric/datasets/webkb.py +2 -2
  82. torch_geometric/datasets/wikics.py +1 -1
  83. torch_geometric/datasets/wikidata.py +3 -2
  84. torch_geometric/datasets/wikipedia_network.py +2 -2
  85. torch_geometric/datasets/willow_object_class.py +1 -1
  86. torch_geometric/datasets/word_net.py +2 -2
  87. torch_geometric/datasets/yelp.py +2 -1
  88. torch_geometric/datasets/zinc.py +1 -1
  89. torch_geometric/device.py +42 -0
  90. torch_geometric/distributed/local_feature_store.py +3 -2
  91. torch_geometric/distributed/local_graph_store.py +2 -1
  92. torch_geometric/distributed/partition.py +9 -8
  93. torch_geometric/edge_index.py +616 -438
  94. torch_geometric/explain/algorithm/base.py +0 -1
  95. torch_geometric/explain/algorithm/graphmask_explainer.py +1 -2
  96. torch_geometric/explain/algorithm/pg_explainer.py +1 -1
  97. torch_geometric/explain/explanation.py +2 -2
  98. torch_geometric/graphgym/checkpoint.py +2 -1
  99. torch_geometric/graphgym/logger.py +4 -4
  100. torch_geometric/graphgym/loss.py +1 -1
  101. torch_geometric/graphgym/utils/agg_runs.py +6 -6
  102. torch_geometric/index.py +826 -0
  103. torch_geometric/inspector.py +8 -3
  104. torch_geometric/io/fs.py +28 -2
  105. torch_geometric/io/npz.py +2 -1
  106. torch_geometric/io/off.py +2 -2
  107. torch_geometric/io/sdf.py +2 -2
  108. torch_geometric/io/tu.py +4 -5
  109. torch_geometric/loader/__init__.py +4 -0
  110. torch_geometric/loader/cluster.py +10 -4
  111. torch_geometric/loader/graph_saint.py +2 -1
  112. torch_geometric/loader/ibmb_loader.py +12 -4
  113. torch_geometric/loader/mixin.py +1 -1
  114. torch_geometric/loader/neighbor_loader.py +1 -1
  115. torch_geometric/loader/neighbor_sampler.py +2 -2
  116. torch_geometric/loader/prefetch.py +1 -1
  117. torch_geometric/loader/rag_loader.py +107 -0
  118. torch_geometric/loader/utils.py +8 -7
  119. torch_geometric/loader/zip_loader.py +10 -0
  120. torch_geometric/metrics/__init__.py +11 -2
  121. torch_geometric/metrics/link_pred.py +159 -34
  122. torch_geometric/nn/aggr/__init__.py +4 -0
  123. torch_geometric/nn/aggr/attention.py +0 -2
  124. torch_geometric/nn/aggr/base.py +2 -4
  125. torch_geometric/nn/aggr/patch_transformer.py +143 -0
  126. torch_geometric/nn/aggr/set_transformer.py +1 -1
  127. torch_geometric/nn/aggr/variance_preserving.py +33 -0
  128. torch_geometric/nn/attention/__init__.py +5 -1
  129. torch_geometric/nn/attention/qformer.py +71 -0
  130. torch_geometric/nn/conv/collect.jinja +7 -4
  131. torch_geometric/nn/conv/cugraph/base.py +8 -12
  132. torch_geometric/nn/conv/edge_conv.py +3 -2
  133. torch_geometric/nn/conv/fused_gat_conv.py +1 -1
  134. torch_geometric/nn/conv/gat_conv.py +35 -7
  135. torch_geometric/nn/conv/gatv2_conv.py +36 -6
  136. torch_geometric/nn/conv/general_conv.py +1 -1
  137. torch_geometric/nn/conv/graph_conv.py +21 -3
  138. torch_geometric/nn/conv/gravnet_conv.py +3 -2
  139. torch_geometric/nn/conv/hetero_conv.py +3 -3
  140. torch_geometric/nn/conv/hgt_conv.py +1 -1
  141. torch_geometric/nn/conv/message_passing.py +138 -87
  142. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  143. torch_geometric/nn/conv/propagate.jinja +9 -1
  144. torch_geometric/nn/conv/rgcn_conv.py +5 -5
  145. torch_geometric/nn/conv/spline_conv.py +4 -4
  146. torch_geometric/nn/conv/x_conv.py +3 -2
  147. torch_geometric/nn/dense/linear.py +11 -6
  148. torch_geometric/nn/fx.py +3 -3
  149. torch_geometric/nn/model_hub.py +3 -1
  150. torch_geometric/nn/models/__init__.py +10 -2
  151. torch_geometric/nn/models/deep_graph_infomax.py +1 -2
  152. torch_geometric/nn/models/dimenet_utils.py +5 -7
  153. torch_geometric/nn/models/g_retriever.py +230 -0
  154. torch_geometric/nn/models/git_mol.py +336 -0
  155. torch_geometric/nn/models/glem.py +385 -0
  156. torch_geometric/nn/models/gnnff.py +0 -1
  157. torch_geometric/nn/models/graph_unet.py +12 -3
  158. torch_geometric/nn/models/jumping_knowledge.py +63 -4
  159. torch_geometric/nn/models/lightgcn.py +1 -1
  160. torch_geometric/nn/models/metapath2vec.py +5 -5
  161. torch_geometric/nn/models/molecule_gpt.py +222 -0
  162. torch_geometric/nn/models/node2vec.py +2 -3
  163. torch_geometric/nn/models/schnet.py +2 -1
  164. torch_geometric/nn/models/signed_gcn.py +3 -3
  165. torch_geometric/nn/module_dict.py +2 -2
  166. torch_geometric/nn/nlp/__init__.py +9 -0
  167. torch_geometric/nn/nlp/llm.py +322 -0
  168. torch_geometric/nn/nlp/sentence_transformer.py +134 -0
  169. torch_geometric/nn/nlp/vision_transformer.py +33 -0
  170. torch_geometric/nn/norm/batch_norm.py +1 -1
  171. torch_geometric/nn/parameter_dict.py +2 -2
  172. torch_geometric/nn/pool/__init__.py +21 -5
  173. torch_geometric/nn/pool/cluster_pool.py +145 -0
  174. torch_geometric/nn/pool/connect/base.py +0 -1
  175. torch_geometric/nn/pool/edge_pool.py +1 -1
  176. torch_geometric/nn/pool/graclus.py +4 -2
  177. torch_geometric/nn/pool/pool.py +8 -2
  178. torch_geometric/nn/pool/select/base.py +0 -1
  179. torch_geometric/nn/pool/voxel_grid.py +3 -2
  180. torch_geometric/nn/resolver.py +1 -1
  181. torch_geometric/nn/sequential.jinja +10 -23
  182. torch_geometric/nn/sequential.py +204 -78
  183. torch_geometric/nn/summary.py +1 -1
  184. torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
  185. torch_geometric/profile/__init__.py +2 -0
  186. torch_geometric/profile/nvtx.py +66 -0
  187. torch_geometric/profile/profiler.py +30 -19
  188. torch_geometric/resolver.py +1 -1
  189. torch_geometric/sampler/base.py +34 -13
  190. torch_geometric/sampler/neighbor_sampler.py +11 -10
  191. torch_geometric/sampler/utils.py +1 -1
  192. torch_geometric/template.py +1 -0
  193. torch_geometric/testing/__init__.py +6 -2
  194. torch_geometric/testing/decorators.py +53 -20
  195. torch_geometric/testing/feature_store.py +1 -1
  196. torch_geometric/transforms/__init__.py +2 -0
  197. torch_geometric/transforms/add_metapaths.py +5 -5
  198. torch_geometric/transforms/add_positional_encoding.py +1 -1
  199. torch_geometric/transforms/delaunay.py +65 -14
  200. torch_geometric/transforms/face_to_edge.py +32 -3
  201. torch_geometric/transforms/gdc.py +7 -6
  202. torch_geometric/transforms/laplacian_lambda_max.py +3 -3
  203. torch_geometric/transforms/mask.py +5 -1
  204. torch_geometric/transforms/node_property_split.py +1 -2
  205. torch_geometric/transforms/pad.py +7 -6
  206. torch_geometric/transforms/random_link_split.py +1 -1
  207. torch_geometric/transforms/remove_self_loops.py +36 -0
  208. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  209. torch_geometric/transforms/to_sparse_tensor.py +1 -1
  210. torch_geometric/transforms/two_hop.py +1 -1
  211. torch_geometric/transforms/virtual_node.py +2 -1
  212. torch_geometric/typing.py +43 -6
  213. torch_geometric/utils/__init__.py +5 -1
  214. torch_geometric/utils/_negative_sampling.py +1 -1
  215. torch_geometric/utils/_normalize_edge_index.py +46 -0
  216. torch_geometric/utils/_scatter.py +38 -12
  217. torch_geometric/utils/_subgraph.py +4 -0
  218. torch_geometric/utils/_tree_decomposition.py +2 -2
  219. torch_geometric/utils/augmentation.py +1 -1
  220. torch_geometric/utils/convert.py +12 -8
  221. torch_geometric/utils/geodesic.py +24 -22
  222. torch_geometric/utils/hetero.py +1 -1
  223. torch_geometric/utils/map.py +8 -2
  224. torch_geometric/utils/smiles.py +65 -27
  225. torch_geometric/utils/sparse.py +39 -25
  226. torch_geometric/visualization/graph.py +3 -4
torch_geometric/data/large_graph_indexer.py (new file)
@@ -0,0 +1,677 @@
+ import os
+ import pickle as pkl
+ import shutil
+ from dataclasses import dataclass
+ from itertools import chain
+ from typing import (
+     Any,
+     Callable,
+     Dict,
+     Iterable,
+     Iterator,
+     List,
+     Optional,
+     Sequence,
+     Set,
+     Tuple,
+     Union,
+ )
+
+ import torch
+ from torch import Tensor
+ from tqdm import tqdm
+
+ from torch_geometric.data import Data
+ from torch_geometric.typing import WITH_PT24
+
+ # Could be any hashable type
+ TripletLike = Tuple[str, str, str]
+
+ KnowledgeGraphLike = Iterable[TripletLike]
+
+
+ def ordered_set(values: Iterable[str]) -> List[str]:
+     return list(dict.fromkeys(values))
+
+
+ # TODO: Refactor Node and Edge funcs and attrs to be accessible via an Enum?
+
+ NODE_PID = "pid"
+
+ NODE_KEYS = {NODE_PID}
+
+ EDGE_PID = "e_pid"
+ EDGE_HEAD = "h"
+ EDGE_RELATION = "r"
+ EDGE_TAIL = "t"
+ EDGE_INDEX = "edge_idx"
+
+ EDGE_KEYS = {EDGE_PID, EDGE_HEAD, EDGE_RELATION, EDGE_TAIL, EDGE_INDEX}
+
+ FeatureValueType = Union[Sequence[Any], Tensor]
+
+
+ @dataclass
+ class MappedFeature:
+     name: str
+     values: FeatureValueType
+
+     def __eq__(self, value: "MappedFeature") -> bool:
+         eq = self.name == value.name
+         if isinstance(self.values, torch.Tensor):
+             eq &= torch.equal(self.values, value.values)
+         else:
+             eq &= self.values == value.values
+         return eq
+
+
+ if WITH_PT24:
+     torch.serialization.add_safe_globals([MappedFeature])
+
+
+ class LargeGraphIndexer:
+     """For a dataset that consists of multiple subgraphs that are assumed to
+     be part of a much larger graph, collate the values into a large graph
+     store to save resources.
+     """
+     def __init__(
+         self,
+         nodes: Iterable[str],
+         edges: KnowledgeGraphLike,
+         node_attr: Optional[Dict[str, List[Any]]] = None,
+         edge_attr: Optional[Dict[str, List[Any]]] = None,
+     ) -> None:
+         r"""Constructs a new index that uniquely catalogs each node and edge
+         by id. Not meant to be used directly.
+
+         Args:
+             nodes (Iterable[str]): Node ids in the graph.
+             edges (KnowledgeGraphLike): Edge ids in the graph.
+             node_attr (Optional[Dict[str, List[Any]]], optional): Mapping from
+                 node attribute name to a list of values in order of unique
+                 node ids. Defaults to None.
+             edge_attr (Optional[Dict[str, List[Any]]], optional): Mapping from
+                 edge attribute name to a list of values in order of unique
+                 edge ids. Defaults to None.
+         """
+         self._nodes: Dict[str, int] = dict()
+         self._edges: Dict[TripletLike, int] = dict()
+
+         self._mapped_node_features: Set[str] = set()
+         self._mapped_edge_features: Set[str] = set()
+
+         if len(nodes) != len(set(nodes)):
+             raise AttributeError("Nodes need to be unique")
+         if len(edges) != len(set(edges)):
+             raise AttributeError("Edges need to be unique")
+
+         if node_attr is not None:
+             # TODO: Validity checks btw nodes and node_attr
+             self.node_attr = node_attr
+             if NODE_KEYS & set(self.node_attr.keys()) != NODE_KEYS:
+                 raise AttributeError(
+                     "Invalid node_attr object. Missing " +
+                     f"{NODE_KEYS - set(self.node_attr.keys())}")
+             elif self.node_attr[NODE_PID] != nodes:
+                 raise AttributeError(
+                     "Nodes provided do not match those in node_attr")
+         else:
+             self.node_attr = dict()
+             self.node_attr[NODE_PID] = nodes
+
+         for i, node in enumerate(self.node_attr[NODE_PID]):
+             self._nodes[node] = i
+
+         if edge_attr is not None:
+             # TODO: Validity checks btw edges and edge_attr
+             self.edge_attr = edge_attr
+
+             if EDGE_KEYS & set(self.edge_attr.keys()) != EDGE_KEYS:
+                 raise AttributeError(
+                     "Invalid edge_attr object. Missing " +
+                     f"{EDGE_KEYS - set(self.edge_attr.keys())}")
+             elif self.edge_attr[EDGE_PID] != edges:
+                 raise AttributeError(
+                     "Edges provided do not match those in edge_attr")
+
+         else:
+             self.edge_attr = dict()
+             for default_key in EDGE_KEYS:
+                 self.edge_attr[default_key] = list()
+             self.edge_attr[EDGE_PID] = edges
+
+             for i, tup in enumerate(edges):
+                 h, r, t = tup
+                 self.edge_attr[EDGE_HEAD].append(h)
+                 self.edge_attr[EDGE_RELATION].append(r)
+                 self.edge_attr[EDGE_TAIL].append(t)
+                 self.edge_attr[EDGE_INDEX].append(
+                     (self._nodes[h], self._nodes[t]))
+
+         for i, tup in enumerate(edges):
+             self._edges[tup] = i
+
+     @classmethod
+     def from_triplets(
+         cls,
+         triplets: KnowledgeGraphLike,
+         pre_transform: Optional[Callable[[TripletLike], TripletLike]] = None,
+     ) -> "LargeGraphIndexer":
+         r"""Generate a new index from a series of triplets that represent edge
+         relations between nodes.
+         Formatted like (source_node, edge, dest_node).
+
+         Args:
+             triplets (KnowledgeGraphLike): Series of triplets representing
+                 knowledge graph relations.
+             pre_transform (Optional[Callable[[TripletLike], TripletLike]]):
+                 Optional preprocessing function to apply to triplets.
+                 Defaults to None.
+
+         Returns:
+             LargeGraphIndexer: Index of unique nodes and edges.
+         """
+         # NOTE: Right now assumes that all trips can be loaded into memory
+         nodes = set()
+         edges = set()
+
+         if pre_transform is not None:
+
+             def apply_transform(
+                     trips: KnowledgeGraphLike) -> Iterator[TripletLike]:
+                 for trip in trips:
+                     yield pre_transform(trip)
+
+             triplets = apply_transform(triplets)
+
+         for h, r, t in triplets:
+
+             for node in (h, t):
+                 nodes.add(node)
+
+             edge_idx = (h, r, t)
+             edges.add(edge_idx)
+
+         return cls(list(nodes), list(edges))
+
+     @classmethod
+     def collate(cls,
+                 graphs: Iterable["LargeGraphIndexer"]) -> "LargeGraphIndexer":
+         r"""Combines a series of large graph indexes into a single large graph
+         index.
+
+         Args:
+             graphs (Iterable[LargeGraphIndexer]): Indices to be
+                 combined.
+
+         Returns:
+             LargeGraphIndexer: Singular unique index for all nodes and edges
+                 in input indices.
+         """
+         # FIXME Needs to merge node attrs and edge attrs?
+         trips = chain.from_iterable([graph.to_triplets() for graph in graphs])
+         return cls.from_triplets(trips)
+
+     def get_unique_node_features(self,
+                                  feature_name: str = NODE_PID) -> List[str]:
+         r"""Get all the unique values for a specific node attribute.
+
+         Args:
+             feature_name (str, optional): Name of feature to get.
+                 Defaults to NODE_PID.
+
+         Returns:
+             List[str]: List of unique values for the specified feature.
+         """
+         try:
+             if feature_name in self._mapped_node_features:
+                 raise IndexError(
+                     "Only non-mapped features can be retrieved uniquely.")
+             return ordered_set(self.get_node_features(feature_name))
+
+         except KeyError:
+             raise AttributeError(
+                 f"Nodes do not have a feature called {feature_name}")
+
+     def add_node_feature(
+         self,
+         new_feature_name: str,
+         new_feature_vals: FeatureValueType,
+         map_from_feature: str = NODE_PID,
+     ) -> None:
+         r"""Adds a new feature that corresponds to each unique node in
+         the graph.
+
+         Args:
+             new_feature_name (str): Name to call the new feature.
+             new_feature_vals (FeatureValueType): Values to map for that
+                 new feature.
+             map_from_feature (str, optional): Key of feature to map from.
+                 Size must match the number of feature values.
+                 Defaults to NODE_PID.
+         """
+         if new_feature_name in self.node_attr:
+             raise AttributeError("Features cannot be overridden once created")
+         if map_from_feature in self._mapped_node_features:
+             raise AttributeError(
+                 f"{map_from_feature} is already a feature mapping.")
+
+         feature_keys = self.get_unique_node_features(map_from_feature)
+         if len(feature_keys) != len(new_feature_vals):
+             raise AttributeError(
+                 f"Expected encodings for {len(feature_keys)} unique features,"
+                 + f" but got {len(new_feature_vals)} encodings.")
+
+         if map_from_feature == NODE_PID:
+             self.node_attr[new_feature_name] = new_feature_vals
+         else:
+             self.node_attr[new_feature_name] = MappedFeature(
+                 name=map_from_feature, values=new_feature_vals)
+             self._mapped_node_features.add(new_feature_name)
+
+     def get_node_features(
+         self,
+         feature_name: str = NODE_PID,
+         pids: Optional[Iterable[str]] = None,
+     ) -> List[Any]:
+         r"""Get node feature values for a given set of unique node ids.
+         Returned values are not necessarily unique.
+
+         Args:
+             feature_name (str, optional): Name of feature to fetch. Defaults
+                 to NODE_PID.
+             pids (Optional[Iterable[str]], optional): Node ids to fetch
+                 for. Defaults to None, which fetches all nodes.
+
+         Returns:
+             List[Any]: Node features corresponding to the specified ids.
+         """
+         if feature_name in self._mapped_node_features:
+             values = self.node_attr[feature_name].values
+         else:
+             values = self.node_attr[feature_name]
+
+         # TODO: torch_geometric.utils.select
+         if isinstance(values, torch.Tensor):
+             idxs = list(
+                 self.get_node_features_iter(feature_name, pids,
+                                             index_only=True))
+             return values[idxs]
+         return list(self.get_node_features_iter(feature_name, pids))
+
+     def get_node_features_iter(
+         self,
+         feature_name: str = NODE_PID,
+         pids: Optional[Iterable[str]] = None,
+         index_only: bool = False,
+     ) -> Iterator[Any]:
+         """Iterator version of get_node_features. If index_only is True,
+         yields indices instead of values.
+         """
+         if pids is None:
+             pids = self.node_attr[NODE_PID]
+
+         if feature_name in self._mapped_node_features:
+             feature_map_info = self.node_attr[feature_name]
+             from_feature_name, to_feature_vals = (
+                 feature_map_info.name,
+                 feature_map_info.values,
+             )
+             from_feature_vals = self.get_unique_node_features(
+                 from_feature_name)
+             feature_mapping = {k: i for i, k in enumerate(from_feature_vals)}
+
+             for pid in pids:
+                 idx = self._nodes[pid]
+                 from_feature_val = self.node_attr[from_feature_name][idx]
+                 to_feature_idx = feature_mapping[from_feature_val]
+                 if index_only:
+                     yield to_feature_idx
+                 else:
+                     yield to_feature_vals[to_feature_idx]
+         else:
+             for pid in pids:
+                 idx = self._nodes[pid]
+                 if index_only:
+                     yield idx
+                 else:
+                     yield self.node_attr[feature_name][idx]
+
+     def get_unique_edge_features(self,
+                                  feature_name: str = EDGE_PID) -> List[str]:
+         r"""Get all the unique values for a specific edge attribute.
+
+         Args:
+             feature_name (str, optional): Name of feature to get.
+                 Defaults to EDGE_PID.
+
+         Returns:
+             List[str]: List of unique values for the specified feature.
+         """
+         try:
+             if feature_name in self._mapped_edge_features:
+                 raise IndexError(
+                     "Only non-mapped features can be retrieved uniquely.")
+             return ordered_set(self.get_edge_features(feature_name))
+         except KeyError:
+             raise AttributeError(
+                 f"Edges do not have a feature called {feature_name}")
+
+     def add_edge_feature(
+         self,
+         new_feature_name: str,
+         new_feature_vals: FeatureValueType,
+         map_from_feature: str = EDGE_PID,
+     ) -> None:
+         r"""Adds a new feature that corresponds to each unique edge in
+         the graph.
+
+         Args:
+             new_feature_name (str): Name to call the new feature.
+             new_feature_vals (FeatureValueType): Values to map for that new
+                 feature.
+             map_from_feature (str, optional): Key of feature to map from.
+                 Size must match the number of feature values.
+                 Defaults to EDGE_PID.
+         """
+         if new_feature_name in self.edge_attr:
+             raise AttributeError("Features cannot be overridden once created")
+         if map_from_feature in self._mapped_edge_features:
+             raise AttributeError(
+                 f"{map_from_feature} is already a feature mapping.")
+
+         feature_keys = self.get_unique_edge_features(map_from_feature)
+         if len(feature_keys) != len(new_feature_vals):
+             raise AttributeError(
+                 f"Expected encodings for {len(feature_keys)} unique features, "
+                 + f"but got {len(new_feature_vals)} encodings.")
+
+         if map_from_feature == EDGE_PID:
+             self.edge_attr[new_feature_name] = new_feature_vals
+         else:
+             self.edge_attr[new_feature_name] = MappedFeature(
+                 name=map_from_feature, values=new_feature_vals)
+             self._mapped_edge_features.add(new_feature_name)
+
+     def get_edge_features(
+         self,
+         feature_name: str = EDGE_PID,
+         pids: Optional[Iterable[str]] = None,
+     ) -> List[Any]:
+         r"""Get edge feature values for a given set of unique edge ids.
+         Returned values are not necessarily unique.
+
+         Args:
+             feature_name (str, optional): Name of feature to fetch.
+                 Defaults to EDGE_PID.
+             pids (Optional[Iterable[str]], optional): Edge ids to fetch
+                 for. Defaults to None, which fetches all edges.
+
+         Returns:
+             List[Any]: Edge features corresponding to the specified ids.
+         """
+         if feature_name in self._mapped_edge_features:
+             values = self.edge_attr[feature_name].values
+         else:
+             values = self.edge_attr[feature_name]
+
+         # TODO: torch_geometric.utils.select
+         if isinstance(values, torch.Tensor):
+             idxs = list(
+                 self.get_edge_features_iter(feature_name, pids,
+                                             index_only=True))
+             return values[idxs]
+         return list(self.get_edge_features_iter(feature_name, pids))
+
+     def get_edge_features_iter(
+         self,
+         feature_name: str = EDGE_PID,
+         pids: Optional[KnowledgeGraphLike] = None,
+         index_only: bool = False,
+     ) -> Iterator[Any]:
+         """Iterator version of get_edge_features. If index_only is True,
+         yields indices instead of values.
+         """
+         if pids is None:
+             pids = self.edge_attr[EDGE_PID]
+
+         if feature_name in self._mapped_edge_features:
+             feature_map_info = self.edge_attr[feature_name]
+             from_feature_name, to_feature_vals = (
+                 feature_map_info.name,
+                 feature_map_info.values,
+             )
+             from_feature_vals = self.get_unique_edge_features(
+                 from_feature_name)
+             feature_mapping = {k: i for i, k in enumerate(from_feature_vals)}
+
+             for pid in pids:
+                 idx = self._edges[pid]
+                 from_feature_val = self.edge_attr[from_feature_name][idx]
+                 to_feature_idx = feature_mapping[from_feature_val]
+                 if index_only:
+                     yield to_feature_idx
+                 else:
+                     yield to_feature_vals[to_feature_idx]
+         else:
+             for pid in pids:
+                 idx = self._edges[pid]
+                 if index_only:
+                     yield idx
+                 else:
+                     yield self.edge_attr[feature_name][idx]
+
+     def to_triplets(self) -> Iterator[TripletLike]:
+         return iter(self.edge_attr[EDGE_PID])
+
+     def save(self, path: str) -> None:
+         if os.path.exists(path):
+             shutil.rmtree(path)
+         os.makedirs(path, exist_ok=True)
+         with open(path + "/edges", "wb") as f:
+             pkl.dump(self._edges, f)
+         with open(path + "/nodes", "wb") as f:
+             pkl.dump(self._nodes, f)
+
+         with open(path + "/mapped_edges", "wb") as f:
+             pkl.dump(self._mapped_edge_features, f)
+         with open(path + "/mapped_nodes", "wb") as f:
+             pkl.dump(self._mapped_node_features, f)
+
+         node_attr_path = path + "/node_attr"
+         os.makedirs(node_attr_path, exist_ok=True)
+         for attr_name, vals in self.node_attr.items():
+             torch.save(vals, node_attr_path + f"/{attr_name}.pt")
+
+         edge_attr_path = path + "/edge_attr"
+         os.makedirs(edge_attr_path, exist_ok=True)
+         for attr_name, vals in self.edge_attr.items():
+             torch.save(vals, edge_attr_path + f"/{attr_name}.pt")
+
+     @classmethod
+     def from_disk(cls, path: str) -> "LargeGraphIndexer":
+         indexer = cls(list(), list())
+         with open(path + "/edges", "rb") as f:
+             indexer._edges = pkl.load(f)
+         with open(path + "/nodes", "rb") as f:
+             indexer._nodes = pkl.load(f)
+
+         with open(path + "/mapped_edges", "rb") as f:
+             indexer._mapped_edge_features = pkl.load(f)
+         with open(path + "/mapped_nodes", "rb") as f:
+             indexer._mapped_node_features = pkl.load(f)
+
+         node_attr_path = path + "/node_attr"
+         for fname in os.listdir(node_attr_path):
+             full_fname = f"{node_attr_path}/{fname}"
+             key = fname.split(".")[0]
+             indexer.node_attr[key] = torch.load(full_fname)
+
+         edge_attr_path = path + "/edge_attr"
+         for fname in os.listdir(edge_attr_path):
+             full_fname = f"{edge_attr_path}/{fname}"
+             key = fname.split(".")[0]
+             indexer.edge_attr[key] = torch.load(full_fname)
+
+         return indexer
+
+     def to_data(self, node_feature_name: str,
+                 edge_feature_name: Optional[str] = None) -> Data:
+         """Return a Data object containing all the specified node and
+         edge features and the graph.
+
+         Args:
+             node_feature_name (str): Feature to use for nodes.
+             edge_feature_name (Optional[str], optional): Feature to use for
+                 edges. Defaults to None.
+
+         Returns:
+             Data: Data object containing the specified node and
+                 edge features and the graph.
+         """
+         x = torch.Tensor(self.get_node_features(node_feature_name))
+         node_id = torch.LongTensor(range(len(x)))
+
+         edge_index = torch.t(
+             torch.LongTensor(self.get_edge_features(EDGE_INDEX)))
+
+         edge_attr = (self.get_edge_features(edge_feature_name)
+                      if edge_feature_name is not None else None)
+         edge_id = torch.LongTensor(range(len(edge_attr)))
+
+         return Data(x=x, edge_index=edge_index, edge_attr=edge_attr,
+                     edge_id=edge_id, node_id=node_id)
+
+     def __eq__(self, value: "LargeGraphIndexer") -> bool:
+         eq = True
+         eq &= self._nodes == value._nodes
+         eq &= self._edges == value._edges
+         eq &= self.node_attr.keys() == value.node_attr.keys()
+         eq &= self.edge_attr.keys() == value.edge_attr.keys()
+         eq &= self._mapped_node_features == value._mapped_node_features
+         eq &= self._mapped_edge_features == value._mapped_edge_features
+
+         for k in self.node_attr:
+             eq &= isinstance(self.node_attr[k], type(value.node_attr[k]))
+             if isinstance(self.node_attr[k], torch.Tensor):
+                 eq &= torch.equal(self.node_attr[k], value.node_attr[k])
+             else:
+                 eq &= self.node_attr[k] == value.node_attr[k]
+         for k in self.edge_attr:
+             eq &= isinstance(self.edge_attr[k], type(value.edge_attr[k]))
+             if isinstance(self.edge_attr[k], torch.Tensor):
+                 eq &= torch.equal(self.edge_attr[k], value.edge_attr[k])
+             else:
+                 eq &= self.edge_attr[k] == value.edge_attr[k]
+         return eq
+
+
+ def get_features_for_triplets_groups(
+     indexer: LargeGraphIndexer,
+     triplet_groups: Iterable[KnowledgeGraphLike],
+     node_feature_name: str = "x",
+     edge_feature_name: str = "edge_attr",
+     pre_transform: Optional[Callable[[TripletLike], TripletLike]] = None,
+     verbose: bool = False,
+ ) -> Iterator[Data]:
+     """Given an indexer and a series of triplet groups (like a dataset),
+     retrieve the specified node and edge features for each triplet from the
+     index.
+
+     Args:
+         indexer (LargeGraphIndexer): Indexer containing desired features.
+         triplet_groups (Iterable[KnowledgeGraphLike]): List of lists of
+             triplets to fetch features for.
+         node_feature_name (str, optional): Node feature to fetch.
+             Defaults to "x".
+         edge_feature_name (str, optional): Edge feature to fetch.
+             Defaults to "edge_attr".
+         pre_transform (Optional[Callable[[TripletLike], TripletLike]]):
+             Optional preprocessing to perform on triplets.
+             Defaults to None.
+         verbose (bool, optional): Whether to print progress. Defaults to False.
+
+     Yields:
+         Iterator[Data]: For each triplet group, yield a data object containing
+             the unique graph and features from the index.
+     """
+     if pre_transform is not None:
+
+         def apply_transform(trips):
+             for trip in trips:
+                 yield pre_transform(tuple(trip))
+
+         # TODO: Make this safe for large amounts of triplets?
+         triplet_groups = (list(apply_transform(triplets))
+                           for triplets in triplet_groups)
+
+     node_keys = []
+     edge_keys = []
+     edge_index = []
+
+     for triplets in tqdm(triplet_groups, disable=not verbose):
+         small_graph_indexer = LargeGraphIndexer.from_triplets(
+             triplets, pre_transform=pre_transform)
+
+         node_keys.append(small_graph_indexer.get_node_features())
+         edge_keys.append(small_graph_indexer.get_edge_features(pids=triplets))
+         edge_index.append(
+             small_graph_indexer.get_edge_features(EDGE_INDEX, triplets))
+
+     node_feats = indexer.get_node_features(feature_name=node_feature_name,
+                                            pids=chain.from_iterable(node_keys))
+     edge_feats = indexer.get_edge_features(feature_name=edge_feature_name,
+                                            pids=chain.from_iterable(edge_keys))
+
+     last_node_idx, last_edge_idx = 0, 0
+     for (nkeys, ekeys, eidx) in zip(node_keys, edge_keys, edge_index):
+         nlen, elen = len(nkeys), len(ekeys)
+         x = torch.Tensor(node_feats[last_node_idx:last_node_idx + nlen])
+         last_node_idx += len(nkeys)
+
+         edge_attr = torch.Tensor(edge_feats[last_edge_idx:last_edge_idx +
+                                             elen])
+         last_edge_idx += len(ekeys)
+
+         edge_idx = torch.LongTensor(eidx).T
+
+         data_obj = Data(x=x, edge_attr=edge_attr, edge_index=edge_idx)
+         data_obj[NODE_PID] = node_keys
+         data_obj[EDGE_PID] = edge_keys
+         data_obj["node_idx"] = [indexer._nodes[k] for k in nkeys]
+         data_obj["edge_idx"] = [indexer._edges[e] for e in ekeys]
+
+         yield data_obj
+
+
+ def get_features_for_triplets(
+     indexer: LargeGraphIndexer,
+     triplets: KnowledgeGraphLike,
+     node_feature_name: str = "x",
+     edge_feature_name: str = "edge_attr",
+     pre_transform: Optional[Callable[[TripletLike], TripletLike]] = None,
+     verbose: bool = False,
+ ) -> Data:
+     """For a given set of triplets retrieve a Data object containing the
+     unique graph and features from the index.
+
+     Args:
+         indexer (LargeGraphIndexer): Indexer containing desired features.
+         triplets (KnowledgeGraphLike): Triplets to fetch features for.
+         node_feature_name (str, optional): Feature to use for node features.
+             Defaults to "x".
+         edge_feature_name (str, optional): Feature to use for edge features.
+             Defaults to "edge_attr".
+         pre_transform (Optional[Callable[[TripletLike], TripletLike]]):
+             Optional preprocessing function for triplets. Defaults to None.
+         verbose (bool, optional): Whether to print progress. Defaults to False.
+
+     Returns:
+         Data: Data object containing the unique graph and features from the
+             index for the given triplets.
+     """
+     gen = get_features_for_triplets_groups(indexer, [triplets],
+                                            node_feature_name,
+                                            edge_feature_name, pre_transform,
+                                            verbose)
+     return next(gen)
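
A minimal usage sketch of the new module, assuming it is importable from the file path shown above; the triplets and random feature tensors below are placeholders for illustration, not data or examples shipped in the release:

import torch

from torch_geometric.data.large_graph_indexer import (
    LargeGraphIndexer,
    get_features_for_triplets,
)

# A toy knowledge graph as (head, relation, tail) triplets:
triplets = [
    ("paris", "capital_of", "france"),
    ("berlin", "capital_of", "germany"),
    ("france", "borders", "germany"),
]

# Build a global index of unique nodes and edges:
indexer = LargeGraphIndexer.from_triplets(triplets)

# Attach one feature vector per unique node, in the indexer's node id order.
# Random 4-dimensional vectors stand in for real embeddings:
num_nodes = len(indexer.get_unique_node_features())
indexer.add_node_feature("x", torch.rand(num_nodes, 4))

# Attach one feature vector per unique relation, mapped from the "r" key:
num_rels = len(indexer.get_unique_edge_features("r"))
indexer.add_edge_feature("edge_attr", torch.rand(num_rels, 4),
                         map_from_feature="r")

# Materialize a PyG Data object for a subset of the triplets:
data = get_features_for_triplets(indexer, triplets[:2],
                                 node_feature_name="x",
                                 edge_feature_name="edge_attr")
print(data)

get_features_for_triplets_groups follows the same pattern for an iterable of triplet lists, yielding one Data object per group.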