pyg-nightly 2.7.0.dev20241210__py3-none-any.whl → 2.7.0.dev20241212__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: pyg-nightly
3
- Version: 2.7.0.dev20241210
3
+ Version: 2.7.0.dev20241212
4
4
  Summary: Graph Neural Network Library for PyTorch
5
5
  Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
6
6
  Author-email: Matthias Fey <matthias@pyg.org>
@@ -1,4 +1,4 @@
1
- torch_geometric/__init__.py,sha256=_D9NOV_9zD7KA-qzPTAx93bn3oq7AKJ5GXk_j14e6tU,1904
1
+ torch_geometric/__init__.py,sha256=rM2co1RdbpOI7hq4w_6b4AmFYqhWfKQFujUwtX-vY2I,1904
2
2
  torch_geometric/_compile.py,sha256=f-WQeH4VLi5Hn9lrgztFUCSrN_FImjhQa6BxFzcYC38,1338
3
3
  torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
4
4
  torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
@@ -18,7 +18,7 @@ torch_geometric/logging.py,sha256=HmHHLiCcM64k-6UYNOSfXPIeSGNAyiGGcn8cD8tlyuQ,85
18
18
  torch_geometric/resolver.py,sha256=fn-_6mCpI2xv7eDZnIFcYrHOn0IrwbkWFLDb9laQrWI,1270
19
19
  torch_geometric/seed.py,sha256=MJLbVwpb9i8mK3oi32sS__Cq-dRq_afTeoOL_HoA9ko,372
20
20
  torch_geometric/template.py,sha256=rqjDWgcSAgTCiV4bkOjWRPaO4PpUdC_RXigzxxBqAu8,1060
21
- torch_geometric/typing.py,sha256=0pxCLue4iqqFC-k5ByqAeyg2mogtWXqgtod3ZOEMq1A,13933
21
+ torch_geometric/typing.py,sha256=PO6jvRjcGkZoMPBEo9GANZN5gUqHV1YEbUBbbdaX1oE,14331
22
22
  torch_geometric/warnings.py,sha256=t114CbkrmiqkXaavx5g7OO52dLdktf-U__B5QqYIQvI,413
23
23
  torch_geometric/contrib/__init__.py,sha256=0pWkmXfZtbdr-AKwlii5LTFggTEH-MCrSKpZxrtPlVs,352
24
24
  torch_geometric/contrib/datasets/__init__.py,sha256=lrGnWsEiJf5zsBRmshGZZFN_uYR2ezDjbj9n9nCpvtk,23
@@ -43,7 +43,7 @@ torch_geometric/data/graph_store.py,sha256=oFrLDNP5hKf3HWWsFsjcamx5vLIEk8JnLjuGp
43
43
  torch_geometric/data/hetero_data.py,sha256=q0L3bENyEvo_BGLPwZPVzh730Aak6sQ7yXoawPgM72E,47982
44
44
  torch_geometric/data/hypergraph_data.py,sha256=33hsXW25Yz4Ju8mKajYinZOrkqrUi1SqThG7MlOOYNM,8294
45
45
  torch_geometric/data/in_memory_dataset.py,sha256=F35hU9Dw3qiJUL5E1CCAfq-1xrlUMstXBmQVEQdtJ1I,13403
46
- torch_geometric/data/large_graph_indexer.py,sha256=JqozKbn5C-jLq2uydeImWqihvBRg8nl5Al55V5s53aw,25433
46
+ torch_geometric/data/large_graph_indexer.py,sha256=eor7F98kPDlrs1v0qypinVbhQr8wnw8mAfsLNiNmEwY,25390
47
47
  torch_geometric/data/makedirs.py,sha256=6uOv4y34i947cm4rv7Aj2_YZBq-EOsyPKnlGA188YSw,463
48
48
  torch_geometric/data/on_disk_dataset.py,sha256=77om-e6kzcpBb77kf7um1xY8-yHmQaao_6R7I-3NwHk,6629
49
49
  torch_geometric/data/remote_backend_utils.py,sha256=Rzpq1PczXuHhUscrFtIAL6dua6pMehSJlXG7yEsrrrg,4503
@@ -262,7 +262,7 @@ torch_geometric/io/ply.py,sha256=NdeTtr79vJ1HS37ZV2N61EUmA5NGJd2I6cUj1Pg7Ypg,489
262
262
  torch_geometric/io/sdf.py,sha256=H2PC6dSW9Kncc1ulb0UN0JnTRT93NY2fY8lf6K4hb50,1165
263
263
  torch_geometric/io/tu.py,sha256=-v5Ago7DfmGTRBtB5RZFvmv4XpLnKKnk-NOnxlHtB_c,4881
264
264
  torch_geometric/io/txt_array.py,sha256=LDeX2qtlNKW-kVe-wpnskMwAdXQp1jVCGQnrJce7Smg,910
265
- torch_geometric/loader/__init__.py,sha256=o0wC0Gvv4rewpZU_YeVaJZCCZZJQG2v8MfZhjocvKp8,1896
265
+ torch_geometric/loader/__init__.py,sha256=DJrdCD1A5PuBYRSgxFbZU9GTBStNuKngqkUV1oEfQQ4,1971
266
266
  torch_geometric/loader/base.py,sha256=ataIwNEYL0px3CN3LJEgXIVTRylDHB6-yBFXXuX2JN0,1615
267
267
  torch_geometric/loader/cache.py,sha256=S65heO3YTyUPbttqizCNtKPHIoAw5iHRpbvw6KlXmok,2106
268
268
  torch_geometric/loader/cluster.py,sha256=eMNxVkvZt5oQ_gJRgmWm1NBX7zU2tZI_BPaXeB0wuyk,13465
@@ -281,7 +281,7 @@ torch_geometric/loader/neighbor_loader.py,sha256=vnLn_RhBKTux5h8pi0vzj0d7JPoOpLA
281
281
  torch_geometric/loader/neighbor_sampler.py,sha256=mraVFXIIGctYot4Xr2VOAhCKAOQyW2gP9KROf7g6tcc,8497
282
282
  torch_geometric/loader/node_loader.py,sha256=g_kV5N0tO6eMSFPc5fdbzfHr4COAeKVJi7FEq52f4zc,11848
283
283
  torch_geometric/loader/prefetch.py,sha256=p1mr54TL4nx3Ea0fBy0JulGYJ8Hq4_9rsiNioZsIW-4,3211
284
- torch_geometric/loader/rag_loader.py,sha256=nwswemzYL4wCKljXqsxMDg07x6PkLU_kgAkNFj5TwUY,4555
284
+ torch_geometric/loader/rag_loader.py,sha256=8hBmccelYOf7HJfdfKLNCpOJYnJ9bFAHeKLkTOuc4CM,4642
285
285
  torch_geometric/loader/random_node_loader.py,sha256=rCmRXYv70SPxBo-Oh049eFEWEZDV7FmlRPzmjcoirXQ,2196
286
286
  torch_geometric/loader/shadow.py,sha256=_hCspYf9SlJYX0lqEjxFec9e9t1iMScNThOoWR1wQGM,4173
287
287
  torch_geometric/loader/temporal_dataloader.py,sha256=AQ2QFeiXKbPp6I8sUeE8H7br-1_yndivXt7Z6_w62zI,2248
@@ -629,6 +629,6 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
629
629
  torch_geometric/visualization/__init__.py,sha256=PyR_4K5SafsJrBr6qWrkjKr6GBL1b7FtZybyXCDEVwY,154
630
630
  torch_geometric/visualization/graph.py,sha256=ZuLPL92yGRi7lxlqsUPwL_EVVXF7P2kMcveTtW79vpA,4784
631
631
  torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
632
- pyg_nightly-2.7.0.dev20241210.dist-info/WHEEL,sha256=CpUCUxeHQbRN5UGRQHYRJorO5Af-Qy_fHMctcQ8DSGI,82
633
- pyg_nightly-2.7.0.dev20241210.dist-info/METADATA,sha256=oi3nlV5IlVoT1QIIXAQzf2sVt-jXU35Kb4OubJdoAUo,62979
634
- pyg_nightly-2.7.0.dev20241210.dist-info/RECORD,,
632
+ pyg_nightly-2.7.0.dev20241212.dist-info/WHEEL,sha256=CpUCUxeHQbRN5UGRQHYRJorO5Af-Qy_fHMctcQ8DSGI,82
633
+ pyg_nightly-2.7.0.dev20241212.dist-info/METADATA,sha256=SgH5YIFmqInilgeBdaYxaHN1mejOmP0RzY6Pfu90xF8,62979
634
+ pyg_nightly-2.7.0.dev20241212.dist-info/RECORD,,
@@ -30,7 +30,7 @@ from .lazy_loader import LazyLoader
30
30
  contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
31
31
  graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
32
32
 
33
- __version__ = '2.7.0.dev20241210'
33
+ __version__ = '2.7.0.dev20241212'
34
34
 
35
35
  __all__ = [
36
36
  'Index',
@@ -7,7 +7,6 @@ from typing import (
7
7
  Any,
8
8
  Callable,
9
9
  Dict,
10
- Hashable,
11
10
  Iterable,
12
11
  Iterator,
13
12
  List,
@@ -25,12 +24,13 @@ from tqdm import tqdm
25
24
  from torch_geometric.data import Data
26
25
  from torch_geometric.typing import WITH_PT24
27
26
 
28
- TripletLike = Tuple[Hashable, Hashable, Hashable]
27
+ # Could be any hashable type
28
+ TripletLike = Tuple[str, str, str]
29
29
 
30
30
  KnowledgeGraphLike = Iterable[TripletLike]
31
31
 
32
32
 
33
- def ordered_set(values: Iterable[Hashable]) -> List[Hashable]:
33
+ def ordered_set(values: Iterable[str]) -> List[str]:
34
34
  return list(dict.fromkeys(values))
35
35
 
36
36
 
@@ -70,13 +70,13 @@ if WITH_PT24:
70
70
 
71
71
 
72
72
  class LargeGraphIndexer:
73
- """For a dataset that consists of mulitiple subgraphs that are assumed to
73
+ """For a dataset that consists of multiple subgraphs that are assumed to
74
74
  be part of a much larger graph, collate the values into a large graph store
75
75
  to save resources.
76
76
  """
77
77
  def __init__(
78
78
  self,
79
- nodes: Iterable[Hashable],
79
+ nodes: Iterable[str],
80
80
  edges: KnowledgeGraphLike,
81
81
  node_attr: Optional[Dict[str, List[Any]]] = None,
82
82
  edge_attr: Optional[Dict[str, List[Any]]] = None,
@@ -85,7 +85,7 @@ class LargeGraphIndexer:
85
85
  by id. Not meant to be used directly.
86
86
 
87
87
  Args:
88
- nodes (Iterable[Hashable]): Node ids in the graph.
88
+ nodes (Iterable[str]): Node ids in the graph.
89
89
  edges (KnowledgeGraphLike): Edge ids in the graph.
90
90
  node_attr (Optional[Dict[str, List[Any]]], optional): Mapping node
91
91
  attribute name and list of their values in order of unique node
@@ -94,7 +94,7 @@ class LargeGraphIndexer:
94
94
  attribute name and list of their values in order of unique edge
95
95
  ids. Defaults to None.
96
96
  """
97
- self._nodes: Dict[Hashable, int] = dict()
97
+ self._nodes: Dict[str, int] = dict()
98
98
  self._edges: Dict[TripletLike, int] = dict()
99
99
 
100
100
  self._mapped_node_features: Set[str] = set()
@@ -201,7 +201,7 @@ class LargeGraphIndexer:
201
201
  index.
202
202
 
203
203
  Args:
204
- graphs (Iterable["LargeGraphIndexer"]): Indices to be
204
+ graphs (Iterable[LargeGraphIndexer]): Indices to be
205
205
  combined.
206
206
 
207
207
  Returns:
@@ -212,8 +212,8 @@ class LargeGraphIndexer:
212
212
  trips = chain.from_iterable([graph.to_triplets() for graph in graphs])
213
213
  return cls.from_triplets(trips)
214
214
 
215
- def get_unique_node_features(
216
- self, feature_name: str = NODE_PID) -> List[Hashable]:
215
+ def get_unique_node_features(self,
216
+ feature_name: str = NODE_PID) -> List[str]:
217
217
  r"""Get all the unique values for a specific node attribute.
218
218
 
219
219
  Args:
@@ -221,7 +221,7 @@ class LargeGraphIndexer:
221
221
  Defaults to NODE_PID.
222
222
 
223
223
  Returns:
224
- List[Hashable]: List of unique values for the specified feature.
224
+ List[str]: List of unique values for the specified feature.
225
225
  """
226
226
  try:
227
227
  if feature_name in self._mapped_node_features:
@@ -272,7 +272,7 @@ class LargeGraphIndexer:
272
272
  def get_node_features(
273
273
  self,
274
274
  feature_name: str = NODE_PID,
275
- pids: Optional[Iterable[Hashable]] = None,
275
+ pids: Optional[Iterable[str]] = None,
276
276
  ) -> List[Any]:
277
277
  r"""Get node feature values for a given set of unique node ids.
278
278
  Returned values are not necessarily unique.
@@ -280,7 +280,7 @@ class LargeGraphIndexer:
280
280
  Args:
281
281
  feature_name (str, optional): Name of feature to fetch. Defaults
282
282
  to NODE_PID.
283
- pids (Optional[Iterable[Hashable]], optional): Node ids to fetch
283
+ pids (Optional[Iterable[str]], optional): Node ids to fetch
284
284
  for. Defaults to None, which fetches all nodes.
285
285
 
286
286
  Returns:
@@ -302,7 +302,7 @@ class LargeGraphIndexer:
302
302
  def get_node_features_iter(
303
303
  self,
304
304
  feature_name: str = NODE_PID,
305
- pids: Optional[Iterable[Hashable]] = None,
305
+ pids: Optional[Iterable[str]] = None,
306
306
  index_only: bool = False,
307
307
  ) -> Iterator[Any]:
308
308
  """Iterator version of get_node_features. If index_only is True,
@@ -337,8 +337,8 @@ class LargeGraphIndexer:
337
337
  else:
338
338
  yield self.node_attr[feature_name][idx]
339
339
 
340
- def get_unique_edge_features(
341
- self, feature_name: str = EDGE_PID) -> List[Hashable]:
340
+ def get_unique_edge_features(self,
341
+ feature_name: str = EDGE_PID) -> List[str]:
342
342
  r"""Get all the unique values for a specific edge attribute.
343
343
 
344
344
  Args:
@@ -346,7 +346,7 @@ class LargeGraphIndexer:
346
346
  Defaults to EDGE_PID.
347
347
 
348
348
  Returns:
349
- List[Hashable]: List of unique values for the specified feature.
349
+ List[str]: List of unique values for the specified feature.
350
350
  """
351
351
  try:
352
352
  if feature_name in self._mapped_edge_features:
@@ -396,7 +396,7 @@ class LargeGraphIndexer:
396
396
  def get_edge_features(
397
397
  self,
398
398
  feature_name: str = EDGE_PID,
399
- pids: Optional[Iterable[Hashable]] = None,
399
+ pids: Optional[Iterable[str]] = None,
400
400
  ) -> List[Any]:
401
401
  r"""Get edge feature values for a given set of unique edge ids.
402
402
  Returned values are not necessarily unique.
@@ -404,7 +404,7 @@ class LargeGraphIndexer:
404
404
  Args:
405
405
  feature_name (str, optional): Name of feature to fetch.
406
406
  Defaults to EDGE_PID.
407
- pids (Optional[Iterable[Hashable]], optional): Edge ids to fetch
407
+ pids (Optional[Iterable[str]], optional): Edge ids to fetch
408
408
  for. Defaults to None, which fetches all edges.
409
409
 
410
410
  Returns:
@@ -22,7 +22,7 @@ from .dynamic_batch_sampler import DynamicBatchSampler
22
22
  from .prefetch import PrefetchLoader
23
23
  from .cache import CachedLoader
24
24
  from .mixin import AffinityMixin
25
- from .rag_loader import RAGQueryLoader
25
+ from .rag_loader import RAGQueryLoader, RAGFeatureStore, RAGGraphStore
26
26
 
27
27
  __all__ = classes = [
28
28
  'DataLoader',
@@ -52,6 +52,8 @@ __all__ = classes = [
52
52
  'CachedLoader',
53
53
  'AffinityMixin',
54
54
  'RAGQueryLoader',
55
+ 'RAGFeatureStore',
56
+ 'RAGGraphStore'
55
57
  ]
56
58
 
57
59
  RandomNodeSampler = deprecated(
@@ -7,7 +7,7 @@ from torch_geometric.typing import InputEdges, InputNodes
7
7
 
8
8
 
9
9
  class RAGFeatureStore(Protocol):
10
- """Feature store for remote GNN RAG backend."""
10
+ """Feature store template for remote GNN RAG backend."""
11
11
  @abstractmethod
12
12
  def retrieve_seed_nodes(self, query: Any, **kwargs) -> InputNodes:
13
13
  """Makes a comparison between the query and all the nodes to get all
@@ -33,7 +33,7 @@ class RAGFeatureStore(Protocol):
33
33
 
34
34
 
35
35
  class RAGGraphStore(Protocol):
36
- """Graph store for remote GNN RAG backend."""
36
+ """Graph store template for remote GNN RAG backend."""
37
37
  @abstractmethod
38
38
  def sample_subgraph(self, seed_nodes: InputNodes, seed_edges: InputEdges,
39
39
  **kwargs) -> Union[SamplerOutput, HeteroSamplerOutput]:
@@ -52,6 +52,7 @@ class RAGGraphStore(Protocol):
52
52
 
53
53
 
54
54
  class RAGQueryLoader:
55
+ """Loader meant for making RAG queries from a remote backend."""
55
56
  def __init__(self, data: Tuple[RAGFeatureStore, RAGGraphStore],
56
57
  local_filter: Optional[Callable[[Data, Any], Data]] = None,
57
58
  seed_nodes_kwargs: Optional[Dict[str, Any]] = None,
torch_geometric/typing.py CHANGED
@@ -307,6 +307,8 @@ class EdgeTypeStr(str):
307
307
  r"""A helper class to construct serializable edge types by merging an edge
308
308
  type tuple into a single string.
309
309
  """
310
+ edge_type: tuple[str, str, str]
311
+
310
312
  def __new__(cls, *args: Any) -> 'EdgeTypeStr':
311
313
  if isinstance(args[0], (list, tuple)):
312
314
  # Unwrap `EdgeType((src, rel, dst))` and `EdgeTypeStr((src, dst))`:
@@ -314,27 +316,34 @@ class EdgeTypeStr(str):
314
316
 
315
317
  if len(args) == 1 and isinstance(args[0], str):
316
318
  arg = args[0] # An edge type string was passed.
319
+ edge_type = tuple(arg.split(EDGE_TYPE_STR_SPLIT))
320
+ if len(edge_type) != 3:
321
+ raise ValueError(f"Cannot convert the edge type '{arg}' to a "
322
+ f"tuple since it holds invalid characters")
317
323
 
318
324
  elif len(args) == 2 and all(isinstance(arg, str) for arg in args):
319
325
  # A `(src, dst)` edge type was passed - add `DEFAULT_REL`:
320
- arg = EDGE_TYPE_STR_SPLIT.join((args[0], DEFAULT_REL, args[1]))
326
+ edge_type = (args[0], DEFAULT_REL, args[1])
327
+ arg = EDGE_TYPE_STR_SPLIT.join(edge_type)
321
328
 
322
329
  elif len(args) == 3 and all(isinstance(arg, str) for arg in args):
323
330
  # A `(src, rel, dst)` edge type was passed:
331
+ edge_type = tuple(args)
324
332
  arg = EDGE_TYPE_STR_SPLIT.join(args)
325
333
 
326
334
  else:
327
335
  raise ValueError(f"Encountered invalid edge type '{args}'")
328
336
 
329
- return str.__new__(cls, arg)
337
+ out = str.__new__(cls, arg)
338
+ out.edge_type = edge_type # type: ignore
339
+ return out
330
340
 
331
341
  def to_tuple(self) -> EdgeType:
332
342
  r"""Returns the original edge type."""
333
- out = tuple(self.split(EDGE_TYPE_STR_SPLIT))
334
- if len(out) != 3:
343
+ if len(self.edge_type) != 3:
335
344
  raise ValueError(f"Cannot convert the edge type '{self}' to a "
336
345
  f"tuple since it holds invalid characters")
337
- return out
346
+ return self.edge_type
338
347
 
339
348
 
340
349
  # There exist some short-cuts to query edge-types (given that the full triplet