pyg-nightly 2.7.0.dev20250822__py3-none-any.whl → 2.7.0.dev20250824__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pyg-nightly
3
- Version: 2.7.0.dev20250822
3
+ Version: 2.7.0.dev20250824
4
4
  Summary: Graph Neural Network Library for PyTorch
5
5
  Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
6
6
  Author-email: Matthias Fey <matthias@pyg.org>
@@ -1,4 +1,4 @@
1
- torch_geometric/__init__.py,sha256=EzPtCTlUiMKhCY8DpZ-PoeX7k2ZTyBUGvpb8T5aKXFg,2250
1
+ torch_geometric/__init__.py,sha256=MgxbhqEQw85B1XawAzP1xK2jTO9z_QcIYtwhfU88AQQ,2250
2
2
  torch_geometric/_compile.py,sha256=9yqMTBKatZPr40WavJz9FjNi7pQj8YZAZOyZmmRGXgc,1351
3
3
  torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
4
4
  torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
@@ -519,10 +519,10 @@ torch_geometric/profile/profile.py,sha256=cHCY4U0XtyqyKC5u380q6TspsOZ5tGHNXaZsKu
519
519
  torch_geometric/profile/profiler.py,sha256=rfNciRzWDka_BgO6aPFi3cy8mcT4lSgFWy-WfPgI2SI,16891
520
520
  torch_geometric/profile/utils.py,sha256=ynlUVemNJZ6XjJKIkPJNwFPoUyBgVAnchfHBpKOp_HE,5903
521
521
  torch_geometric/sampler/__init__.py,sha256=0h_xJ7CQnlTxF5hUpc81WPQ0QaBtouG8eKK1RzPGA-s,512
522
- torch_geometric/sampler/base.py,sha256=T7RMx14RSlEKlQUkMvR1EzREaXi14VgR5GIwLfvbXzQ,27055
522
+ torch_geometric/sampler/base.py,sha256=0z5e6DYqUGa1-eyJKUGCyUJe6T-DoKBVpOdNxD-dOzM,29225
523
523
  torch_geometric/sampler/hgt_sampler.py,sha256=jizRJyEoz4WBOEELuqdytG2hB3UpVQX7yVPM83kvpfE,2991
524
524
  torch_geometric/sampler/neighbor_sampler.py,sha256=GBnIwPF5pf85wDo_FaMKtMgBlQhzRx5ysiAomMNhCrQ,34092
525
- torch_geometric/sampler/utils.py,sha256=RJtasO6Q7Pp3oYEOWrbf2DEYuSfuKZOsF2I7-eJDnoA,5485
525
+ torch_geometric/sampler/utils.py,sha256=U4XGCEBX6rhmKHpCwYjskNvKNnRTXb9GeVKR1vSWuvA,7250
526
526
  torch_geometric/testing/__init__.py,sha256=m3yp_5UnCAxVgzTFofpiVt0vdbl5GwVAve8WTrAaNxo,1319
527
527
  torch_geometric/testing/asserts.py,sha256=DLC9HnBgFWuTIiQs2OalsQcXGhOVG-e6R99IWhkO32c,4606
528
528
  torch_geometric/testing/data.py,sha256=O1qo8FyNxt6RGf63Ys3eXBfa5RvYydeZLk74szrez3c,2604
@@ -645,7 +645,7 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
645
645
  torch_geometric/visualization/__init__.py,sha256=b-HnVesXjyJ_L1N-DnjiRiRVf7lhwKaBQF_2i5YMVSU,208
646
646
  torch_geometric/visualization/graph.py,sha256=mfZHXYfiU-CWMtfawYc80IxVwVmtK9hbIkSKhM_j7oI,14311
647
647
  torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
648
- pyg_nightly-2.7.0.dev20250822.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
649
- pyg_nightly-2.7.0.dev20250822.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
650
- pyg_nightly-2.7.0.dev20250822.dist-info/METADATA,sha256=dM0G4qt1GnfQIJxN5QRCrRemxyP3r1iyITMakFRVMAw,64100
651
- pyg_nightly-2.7.0.dev20250822.dist-info/RECORD,,
648
+ pyg_nightly-2.7.0.dev20250824.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
649
+ pyg_nightly-2.7.0.dev20250824.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
650
+ pyg_nightly-2.7.0.dev20250824.dist-info/METADATA,sha256=e414Fyh9TNPZurzgwbAnQfQufV29PLugLmumo4GOIyk,64100
651
+ pyg_nightly-2.7.0.dev20250824.dist-info/RECORD,,
@@ -31,7 +31,7 @@ from .lazy_loader import LazyLoader
31
31
  contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
32
32
  graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
33
33
 
34
- __version__ = '2.7.0.dev20250822'
34
+ __version__ = '2.7.0.dev20250824'
35
35
 
36
36
  __all__ = [
37
37
  'Index',
@@ -11,7 +11,10 @@ import torch
11
11
  from torch import Tensor
12
12
 
13
13
  from torch_geometric.data import Data, FeatureStore, GraphStore, HeteroData
14
- from torch_geometric.sampler.utils import to_bidirectional
14
+ from torch_geometric.sampler.utils import (
15
+ local_to_global_node_idx,
16
+ to_bidirectional,
17
+ )
15
18
  from torch_geometric.typing import EdgeType, EdgeTypeStr, NodeType, OptTensor
16
19
  from torch_geometric.utils.mixin import CastMixin
17
20
 
@@ -207,6 +210,29 @@ class SamplerOutput(CastMixin):
207
210
  # API for the expected output of a sampler.
208
211
  metadata: Optional[Any] = None
209
212
 
213
@property
def global_row(self) -> Tensor:
    r"""The entries of :obj:`row` translated from positions in :obj:`node`
    to the values stored at those positions (i.e. ``node[row]``)."""
    return local_to_global_node_idx(node_values=self.node,
                                    local_indices=self.row)
216
+
217
@property
def global_col(self) -> Tensor:
    r"""The entries of :obj:`col` translated from positions in :obj:`node`
    to the values stored at those positions (i.e. ``node[col]``)."""
    return local_to_global_node_idx(node_values=self.node,
                                    local_indices=self.col)
220
+
221
@property
def seed_node(self) -> Optional[Tensor]:
    r"""The entries of :obj:`batch` translated from positions in
    :obj:`node` to the values stored at those positions
    (i.e. ``node[batch]``).

    Returns :obj:`None` when :obj:`batch` is :obj:`None` (the annotation
    previously claimed a plain :class:`Tensor` was always returned).
    """
    if self.batch is None:
        return None
    return local_to_global_node_idx(self.node, self.batch)
225
+
226
@property
def global_orig_row(self) -> Optional[Tensor]:
    r"""The entries of :obj:`orig_row` translated from positions in
    :obj:`node` to the values stored at those positions
    (i.e. ``node[orig_row]``).

    Returns :obj:`None` when :obj:`orig_row` is :obj:`None`.
    """
    if self.orig_row is None:
        return None
    return local_to_global_node_idx(self.node, self.orig_row)
230
+
231
@property
def global_orig_col(self) -> Optional[Tensor]:
    r"""The entries of :obj:`orig_col` translated from positions in
    :obj:`node` to the values stored at those positions
    (i.e. ``node[orig_col]``).

    Returns :obj:`None` when :obj:`orig_col` is :obj:`None`.
    """
    if self.orig_col is None:
        return None
    return local_to_global_node_idx(self.node, self.orig_col)
235
+
210
236
  def to_bidirectional(
211
237
  self,
212
238
  keep_orig_edges: bool = False,
@@ -294,6 +320,43 @@ class HeteroSamplerOutput(CastMixin):
294
320
  # API for the expected output of a sampler.
295
321
  metadata: Optional[Any] = None
296
322
 
323
@property
def global_row(self) -> Dict[EdgeType, Tensor]:
    r"""Per edge type, the entries of :obj:`row[edge_type]` translated from
    positions in :obj:`node[edge_type[0]]` to the values stored there."""
    out: Dict[EdgeType, Tensor] = {}
    for edge_type, local_row in self.row.items():
        src_values = self.node[edge_type[0]]
        out[edge_type] = local_to_global_node_idx(node_values=src_values,
                                                  local_indices=local_row)
    return out
329
+
330
@property
def global_col(self) -> Dict[EdgeType, Tensor]:
    r"""Per edge type, the entries of :obj:`col[edge_type]` translated from
    positions in :obj:`node[edge_type[2]]` to the values stored there."""
    out: Dict[EdgeType, Tensor] = {}
    for edge_type, local_col in self.col.items():
        dst_values = self.node[edge_type[2]]
        out[edge_type] = local_to_global_node_idx(node_values=dst_values,
                                                  local_indices=local_col)
    return out
336
+
337
@property
def seed_node(self) -> Optional[Dict[NodeType, Tensor]]:
    r"""Per node type, the entries of :obj:`batch[node_type]` translated
    from positions in :obj:`node[node_type]` to the values stored there.

    Returns :obj:`None` when :obj:`batch` is :obj:`None`.
    """
    if self.batch is None:
        return None
    out: Dict[NodeType, Tensor] = {}
    for node_type, local_batch in self.batch.items():
        out[node_type] = local_to_global_node_idx(
            node_values=self.node[node_type], local_indices=local_batch)
    return out
343
+
344
@property
def global_orig_row(self) -> Optional[Dict[EdgeType, Tensor]]:
    r"""Per edge type, the entries of :obj:`orig_row[edge_type]` translated
    from positions in :obj:`node[edge_type[0]]` to the values stored there.

    Returns :obj:`None` when :obj:`orig_row` is :obj:`None`.
    """
    if self.orig_row is None:
        return None
    out: Dict[EdgeType, Tensor] = {}
    for edge_type, local_row in self.orig_row.items():
        src_values = self.node[edge_type[0]]
        out[edge_type] = local_to_global_node_idx(node_values=src_values,
                                                  local_indices=local_row)
    return out
351
+
352
@property
def global_orig_col(self) -> Optional[Dict[EdgeType, Tensor]]:
    r"""Per edge type, the entries of :obj:`orig_col[edge_type]` translated
    from positions in :obj:`node[edge_type[2]]` to the values stored there.

    Returns :obj:`None` when :obj:`orig_col` is :obj:`None`.
    """
    if self.orig_col is None:
        return None
    out: Dict[EdgeType, Tensor] = {}
    for edge_type, local_col in self.orig_col.items():
        dst_values = self.node[edge_type[2]]
        out[edge_type] = local_to_global_node_idx(node_values=dst_values,
                                                  local_indices=local_col)
    return out
359
+
297
360
  def to_bidirectional(
298
361
  self,
299
362
  keep_orig_edges: bool = False,
@@ -160,3 +160,48 @@ def remap_keys(
160
160
  k if k in exclude else mapping.get(k, k): v
161
161
  for k, v in inputs.items()
162
162
  }
163
+
164
+
165
def local_to_global_node_idx(node_values: Tensor,
                             local_indices: Tensor) -> Tensor:
    r"""Look up the entries of :obj:`node_values` at :obj:`local_indices`.

    Row ``i`` of the result is ``node_values[local_indices[i]]``. In other
    words, positions into :obj:`node_values` are translated into the values
    stored at those positions along its first dimension.

    Args:
        node_values (torch.Tensor): Tensor gathered along dimension ``0``,
            e.g. of shape ``[num_nodes]`` or ``[num_nodes, feature_dim]``.
        local_indices (torch.Tensor): One-dimensional integer tensor of
            positions into :obj:`node_values`.

    Returns:
        torch.Tensor: Tensor of shape
        ``[num_indices, *node_values.shape[1:]]``.
    """
    return node_values.index_select(0, local_indices)
179
+
180
+
181
def global_to_local_node_idx(node_values: Tensor,
                             local_values: Tensor) -> Tensor:
    r"""Find, for every entry of :obj:`local_values`, its position in
    :obj:`node_values`. This is the inverse of
    :meth:`local_to_global_node_idx`.

    One-dimensional inputs are treated as single-feature rows. A row of
    :obj:`local_values` matches a row of :obj:`node_values` when all of its
    features are equal.

    Args:
        node_values (torch.Tensor): The values to search in, of shape
            ``[num_nodes]`` or ``[num_nodes, feature_dim]``.
        local_values (torch.Tensor): The values to locate, of shape
            ``[num_indices]`` or ``[num_indices, feature_dim]``.

    Returns:
        torch.Tensor: A :obj:`torch.long` tensor of shape ``[num_indices]``
        whose ``i``-th entry is the row of :obj:`node_values` equal to
        ``local_values[i]``. If a value occurs several times in
        :obj:`node_values`, the position of its first occurrence is returned.

    Raises:
        ValueError: If some entry of :obj:`local_values` does not occur in
            :obj:`node_values`. Silently dropping such entries, as before,
            would misalign the output with the input.
    """
    if node_values.dim() == 1:
        node_values = node_values.unsqueeze(1)
    if local_values.dim() == 1:
        local_values = local_values.unsqueeze(1)

    device = node_values.device
    num_nodes = node_values.size(0)
    if local_values.size(0) == 0:
        return torch.empty(0, dtype=torch.long, device=device)

    # Give every distinct row across both inputs a shared integer ID, so that
    # equal rows in `node_values` and `local_values` receive the same ID.
    # This sort-based pass replaces the previous pairwise comparison, which
    # materialised a (num_nodes, feature_dim, num_indices) tensor.
    unique_rows, inverse = torch.unique(
        torch.cat([node_values, local_values], dim=0),
        dim=0,
        return_inverse=True,
    )
    node_ids = inverse[:num_nodes]
    query_ids = inverse[num_nodes:]

    # For every distinct row, record the first position at which it occurs
    # in `node_values`. `num_nodes` serves as a "not present" sentinel.
    first_pos = torch.full((unique_rows.size(0), ), num_nodes,
                           dtype=torch.long, device=device)
    first_pos.scatter_reduce_(0, node_ids,
                              torch.arange(num_nodes, device=device),
                              reduce='amin', include_self=True)

    out = first_pos[query_ids]
    if bool((out == num_nodes).any()):
        raise ValueError("Some entries of 'local_values' are not contained "
                         "in 'node_values'")
    return out