pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pyg-nightly might be problematic. Click here for more details.

Files changed (268)
  1. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
  2. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
  3. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +34 -1
  6. torch_geometric/_compile.py +11 -3
  7. torch_geometric/_onnx.py +228 -0
  8. torch_geometric/config_mixin.py +8 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/__init__.py +19 -1
  13. torch_geometric/data/batch.py +2 -2
  14. torch_geometric/data/collate.py +1 -3
  15. torch_geometric/data/data.py +110 -6
  16. torch_geometric/data/database.py +19 -5
  17. torch_geometric/data/dataset.py +14 -9
  18. torch_geometric/data/extract.py +1 -1
  19. torch_geometric/data/feature_store.py +17 -22
  20. torch_geometric/data/graph_store.py +3 -2
  21. torch_geometric/data/hetero_data.py +139 -7
  22. torch_geometric/data/hypergraph_data.py +2 -2
  23. torch_geometric/data/in_memory_dataset.py +2 -2
  24. torch_geometric/data/lightning/datamodule.py +42 -28
  25. torch_geometric/data/storage.py +9 -1
  26. torch_geometric/datasets/__init__.py +20 -1
  27. torch_geometric/datasets/actor.py +7 -9
  28. torch_geometric/datasets/airfrans.py +17 -20
  29. torch_geometric/datasets/airports.py +8 -10
  30. torch_geometric/datasets/amazon.py +8 -11
  31. torch_geometric/datasets/amazon_book.py +8 -9
  32. torch_geometric/datasets/amazon_products.py +7 -9
  33. torch_geometric/datasets/aminer.py +8 -9
  34. torch_geometric/datasets/aqsol.py +10 -13
  35. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  36. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  37. torch_geometric/datasets/ba_shapes.py +5 -6
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/city.py +157 -0
  40. torch_geometric/datasets/dbp15k.py +1 -1
  41. torch_geometric/datasets/gdelt_lite.py +3 -2
  42. torch_geometric/datasets/ged_dataset.py +3 -2
  43. torch_geometric/datasets/git_mol_dataset.py +263 -0
  44. torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
  45. torch_geometric/datasets/hgb_dataset.py +2 -2
  46. torch_geometric/datasets/hm.py +1 -1
  47. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  48. torch_geometric/datasets/linkx_dataset.py +4 -3
  49. torch_geometric/datasets/lrgb.py +3 -5
  50. torch_geometric/datasets/malnet_tiny.py +2 -1
  51. torch_geometric/datasets/md17.py +3 -3
  52. torch_geometric/datasets/medshapenet.py +145 -0
  53. torch_geometric/datasets/mnist_superpixels.py +2 -3
  54. torch_geometric/datasets/modelnet.py +1 -1
  55. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  56. torch_geometric/datasets/molecule_net.py +3 -2
  57. torch_geometric/datasets/neurograph.py +1 -3
  58. torch_geometric/datasets/ogb_mag.py +1 -1
  59. torch_geometric/datasets/opf.py +19 -5
  60. torch_geometric/datasets/pascal_pf.py +1 -1
  61. torch_geometric/datasets/pcqm4m.py +2 -1
  62. torch_geometric/datasets/ppi.py +2 -1
  63. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  64. torch_geometric/datasets/qm7.py +1 -1
  65. torch_geometric/datasets/qm9.py +3 -2
  66. torch_geometric/datasets/shrec2016.py +2 -2
  67. torch_geometric/datasets/snap_dataset.py +8 -4
  68. torch_geometric/datasets/tag_dataset.py +462 -0
  69. torch_geometric/datasets/teeth3ds.py +269 -0
  70. torch_geometric/datasets/web_qsp_dataset.py +342 -0
  71. torch_geometric/datasets/wikics.py +2 -1
  72. torch_geometric/datasets/wikidata.py +2 -1
  73. torch_geometric/deprecation.py +1 -1
  74. torch_geometric/distributed/__init__.py +13 -0
  75. torch_geometric/distributed/dist_loader.py +2 -2
  76. torch_geometric/distributed/local_feature_store.py +3 -2
  77. torch_geometric/distributed/local_graph_store.py +2 -1
  78. torch_geometric/distributed/partition.py +9 -8
  79. torch_geometric/distributed/rpc.py +3 -3
  80. torch_geometric/edge_index.py +35 -22
  81. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  82. torch_geometric/explain/algorithm/base.py +2 -2
  83. torch_geometric/explain/algorithm/captum.py +1 -1
  84. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  85. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  86. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  87. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  88. torch_geometric/explain/explainer.py +2 -2
  89. torch_geometric/explain/explanation.py +89 -5
  90. torch_geometric/explain/metric/faithfulness.py +1 -1
  91. torch_geometric/graphgym/checkpoint.py +2 -1
  92. torch_geometric/graphgym/config.py +3 -2
  93. torch_geometric/graphgym/imports.py +15 -4
  94. torch_geometric/graphgym/logger.py +1 -1
  95. torch_geometric/graphgym/loss.py +1 -1
  96. torch_geometric/graphgym/models/encoder.py +2 -2
  97. torch_geometric/graphgym/models/layer.py +1 -1
  98. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  99. torch_geometric/hash_tensor.py +798 -0
  100. torch_geometric/index.py +16 -7
  101. torch_geometric/inspector.py +6 -2
  102. torch_geometric/io/fs.py +27 -0
  103. torch_geometric/io/tu.py +2 -3
  104. torch_geometric/llm/__init__.py +9 -0
  105. torch_geometric/llm/large_graph_indexer.py +741 -0
  106. torch_geometric/llm/models/__init__.py +23 -0
  107. torch_geometric/llm/models/g_retriever.py +251 -0
  108. torch_geometric/llm/models/git_mol.py +336 -0
  109. torch_geometric/llm/models/glem.py +397 -0
  110. torch_geometric/llm/models/llm.py +470 -0
  111. torch_geometric/llm/models/llm_judge.py +158 -0
  112. torch_geometric/llm/models/molecule_gpt.py +222 -0
  113. torch_geometric/llm/models/protein_mpnn.py +333 -0
  114. torch_geometric/llm/models/sentence_transformer.py +188 -0
  115. torch_geometric/llm/models/txt2kg.py +353 -0
  116. torch_geometric/llm/models/vision_transformer.py +38 -0
  117. torch_geometric/llm/rag_loader.py +154 -0
  118. torch_geometric/llm/utils/__init__.py +10 -0
  119. torch_geometric/llm/utils/backend_utils.py +443 -0
  120. torch_geometric/llm/utils/feature_store.py +169 -0
  121. torch_geometric/llm/utils/graph_store.py +199 -0
  122. torch_geometric/llm/utils/vectorrag.py +125 -0
  123. torch_geometric/loader/cluster.py +6 -5
  124. torch_geometric/loader/graph_saint.py +2 -1
  125. torch_geometric/loader/ibmb_loader.py +4 -4
  126. torch_geometric/loader/link_loader.py +1 -1
  127. torch_geometric/loader/link_neighbor_loader.py +2 -1
  128. torch_geometric/loader/mixin.py +6 -5
  129. torch_geometric/loader/neighbor_loader.py +1 -1
  130. torch_geometric/loader/neighbor_sampler.py +2 -2
  131. torch_geometric/loader/prefetch.py +4 -3
  132. torch_geometric/loader/temporal_dataloader.py +2 -2
  133. torch_geometric/loader/utils.py +10 -10
  134. torch_geometric/metrics/__init__.py +23 -2
  135. torch_geometric/metrics/link_pred.py +755 -85
  136. torch_geometric/nn/__init__.py +1 -0
  137. torch_geometric/nn/aggr/__init__.py +2 -0
  138. torch_geometric/nn/aggr/base.py +1 -1
  139. torch_geometric/nn/aggr/equilibrium.py +1 -1
  140. torch_geometric/nn/aggr/fused.py +1 -1
  141. torch_geometric/nn/aggr/patch_transformer.py +149 -0
  142. torch_geometric/nn/aggr/set_transformer.py +1 -1
  143. torch_geometric/nn/aggr/utils.py +9 -4
  144. torch_geometric/nn/attention/__init__.py +9 -1
  145. torch_geometric/nn/attention/polynormer.py +107 -0
  146. torch_geometric/nn/attention/qformer.py +71 -0
  147. torch_geometric/nn/attention/sgformer.py +99 -0
  148. torch_geometric/nn/conv/__init__.py +2 -0
  149. torch_geometric/nn/conv/appnp.py +1 -1
  150. torch_geometric/nn/conv/collect.jinja +6 -3
  151. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  152. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  153. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  154. torch_geometric/nn/conv/dna_conv.py +1 -1
  155. torch_geometric/nn/conv/eg_conv.py +7 -7
  156. torch_geometric/nn/conv/gat_conv.py +33 -4
  157. torch_geometric/nn/conv/gatv2_conv.py +35 -4
  158. torch_geometric/nn/conv/gen_conv.py +1 -1
  159. torch_geometric/nn/conv/general_conv.py +1 -1
  160. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  161. torch_geometric/nn/conv/hetero_conv.py +3 -2
  162. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  163. torch_geometric/nn/conv/message_passing.py +6 -5
  164. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  165. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  166. torch_geometric/nn/conv/sg_conv.py +1 -1
  167. torch_geometric/nn/conv/spline_conv.py +2 -1
  168. torch_geometric/nn/conv/ssg_conv.py +1 -1
  169. torch_geometric/nn/conv/transformer_conv.py +5 -3
  170. torch_geometric/nn/data_parallel.py +5 -4
  171. torch_geometric/nn/dense/linear.py +5 -24
  172. torch_geometric/nn/encoding.py +17 -3
  173. torch_geometric/nn/fx.py +17 -15
  174. torch_geometric/nn/model_hub.py +5 -16
  175. torch_geometric/nn/models/__init__.py +11 -0
  176. torch_geometric/nn/models/attentive_fp.py +1 -1
  177. torch_geometric/nn/models/attract_repel.py +148 -0
  178. torch_geometric/nn/models/basic_gnn.py +2 -1
  179. torch_geometric/nn/models/captum.py +1 -1
  180. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  181. torch_geometric/nn/models/dimenet.py +2 -2
  182. torch_geometric/nn/models/dimenet_utils.py +4 -2
  183. torch_geometric/nn/models/gpse.py +1083 -0
  184. torch_geometric/nn/models/graph_unet.py +13 -4
  185. torch_geometric/nn/models/lpformer.py +783 -0
  186. torch_geometric/nn/models/metapath2vec.py +1 -1
  187. torch_geometric/nn/models/mlp.py +4 -2
  188. torch_geometric/nn/models/node2vec.py +1 -1
  189. torch_geometric/nn/models/polynormer.py +206 -0
  190. torch_geometric/nn/models/rev_gnn.py +3 -3
  191. torch_geometric/nn/models/schnet.py +2 -1
  192. torch_geometric/nn/models/sgformer.py +219 -0
  193. torch_geometric/nn/models/signed_gcn.py +1 -1
  194. torch_geometric/nn/models/visnet.py +2 -2
  195. torch_geometric/nn/norm/batch_norm.py +17 -7
  196. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  197. torch_geometric/nn/norm/graph_norm.py +9 -4
  198. torch_geometric/nn/norm/instance_norm.py +5 -1
  199. torch_geometric/nn/norm/layer_norm.py +15 -7
  200. torch_geometric/nn/norm/msg_norm.py +8 -2
  201. torch_geometric/nn/pool/__init__.py +15 -9
  202. torch_geometric/nn/pool/cluster_pool.py +144 -0
  203. torch_geometric/nn/pool/connect/base.py +1 -3
  204. torch_geometric/nn/pool/edge_pool.py +1 -1
  205. torch_geometric/nn/pool/knn.py +13 -10
  206. torch_geometric/nn/pool/select/base.py +1 -4
  207. torch_geometric/nn/summary.py +1 -1
  208. torch_geometric/nn/to_hetero_module.py +4 -3
  209. torch_geometric/nn/to_hetero_transformer.py +3 -3
  210. torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
  211. torch_geometric/profile/__init__.py +2 -0
  212. torch_geometric/profile/nvtx.py +66 -0
  213. torch_geometric/profile/profiler.py +18 -9
  214. torch_geometric/profile/utils.py +20 -5
  215. torch_geometric/sampler/__init__.py +2 -1
  216. torch_geometric/sampler/base.py +337 -8
  217. torch_geometric/sampler/hgt_sampler.py +11 -1
  218. torch_geometric/sampler/neighbor_sampler.py +298 -25
  219. torch_geometric/sampler/utils.py +93 -5
  220. torch_geometric/testing/__init__.py +4 -0
  221. torch_geometric/testing/decorators.py +35 -5
  222. torch_geometric/testing/distributed.py +1 -1
  223. torch_geometric/transforms/__init__.py +4 -0
  224. torch_geometric/transforms/add_gpse.py +49 -0
  225. torch_geometric/transforms/add_metapaths.py +10 -8
  226. torch_geometric/transforms/add_positional_encoding.py +2 -2
  227. torch_geometric/transforms/base_transform.py +2 -1
  228. torch_geometric/transforms/delaunay.py +65 -15
  229. torch_geometric/transforms/face_to_edge.py +32 -3
  230. torch_geometric/transforms/gdc.py +8 -9
  231. torch_geometric/transforms/largest_connected_components.py +1 -1
  232. torch_geometric/transforms/mask.py +5 -1
  233. torch_geometric/transforms/node_property_split.py +1 -1
  234. torch_geometric/transforms/normalize_features.py +3 -3
  235. torch_geometric/transforms/pad.py +1 -1
  236. torch_geometric/transforms/random_link_split.py +1 -1
  237. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  238. torch_geometric/transforms/remove_self_loops.py +36 -0
  239. torch_geometric/transforms/rooted_subgraph.py +1 -1
  240. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  241. torch_geometric/transforms/virtual_node.py +2 -1
  242. torch_geometric/typing.py +82 -17
  243. torch_geometric/utils/__init__.py +6 -1
  244. torch_geometric/utils/_lexsort.py +0 -9
  245. torch_geometric/utils/_negative_sampling.py +28 -13
  246. torch_geometric/utils/_normalize_edge_index.py +46 -0
  247. torch_geometric/utils/_scatter.py +126 -164
  248. torch_geometric/utils/_sort_edge_index.py +0 -2
  249. torch_geometric/utils/_spmm.py +16 -14
  250. torch_geometric/utils/_subgraph.py +4 -0
  251. torch_geometric/utils/_tree_decomposition.py +1 -1
  252. torch_geometric/utils/_trim_to_layer.py +2 -2
  253. torch_geometric/utils/augmentation.py +1 -1
  254. torch_geometric/utils/convert.py +17 -10
  255. torch_geometric/utils/cross_entropy.py +34 -13
  256. torch_geometric/utils/embedding.py +91 -2
  257. torch_geometric/utils/geodesic.py +28 -25
  258. torch_geometric/utils/influence.py +279 -0
  259. torch_geometric/utils/map.py +14 -10
  260. torch_geometric/utils/nested.py +1 -1
  261. torch_geometric/utils/smiles.py +3 -3
  262. torch_geometric/utils/sparse.py +32 -24
  263. torch_geometric/visualization/__init__.py +2 -1
  264. torch_geometric/visualization/graph.py +250 -5
  265. torch_geometric/warnings.py +11 -2
  266. torch_geometric/nn/nlp/__init__.py +0 -7
  267. torch_geometric/nn/nlp/llm.py +0 -283
  268. torch_geometric/nn/nlp/sentence_transformer.py +0 -94
@@ -0,0 +1,188 @@
1
from enum import Enum
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor
from tqdm import tqdm
8
+
9
+
10
class PoolingStrategy(Enum):
    """Reduction applied to a model's token embeddings.

    The member values are the strings accepted as ``pooling_strategy``.
    """
    # Masked average over all non-padding tokens.
    MEAN = 'mean'
    # Embedding of the last non-padding token.
    LAST = 'last'
    # Embedding of the token at position 0.
    CLS = 'cls'
    # The unpooled last hidden state.
    LAST_HIDDEN_STATE = 'last_hidden_state'
15
+
16
+
17
class SentenceTransformer(torch.nn.Module):
    r"""A wrapper around a Sentence-Transformer from HuggingFace.

    Args:
        model_name (str): The HuggingFace model name, *e.g.*, :obj:`"BERT"`.
        pooling_strategy (str, optional): The pooling strategy to use
            for generating node embeddings, one of :obj:`"mean"`,
            :obj:`"last"`, :obj:`"cls"` or :obj:`"last_hidden_state"`.
            (default: :obj:`"mean"`)
    """
    def __init__(
        self,
        model_name: str,
        pooling_strategy: 'Union[PoolingStrategy, str]' = 'mean',
    ) -> None:
        super().__init__()

        self.model_name = model_name
        self.pooling_strategy = PoolingStrategy(pooling_strategy)

        from transformers import AutoModel, AutoTokenizer

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        if self.tokenizer.pad_token is None:
            # Batched tokenization needs a padding token; reuse the
            # end-of-sequence token when the tokenizer defines none.
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Some models define a maximum sequence length in their configuration
        # (e.g. 8192 for models like ModernBERT), others only in the
        # tokenizer, so we use the tighter of the two limits.
        # The tokenizer limit is read from `model_max_length` directly. For
        # tokenizers without a predefined limit this is a huge sentinel value
        # that `min` simply ignores. By contrast, probing with
        # `padding='max_length'` silently falls back to *no* padding for such
        # tokenizers and would cap the length at the probe's own token count.
        self.max_seq_length = self.model.config.max_position_embeddings
        tokenizer_max_length = getattr(self.tokenizer, 'model_max_length',
                                       None)
        if isinstance(tokenizer_max_length, int) and tokenizer_max_length > 0:
            self.max_seq_length = min(self.max_seq_length,
                                      tokenizer_max_length)

    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
        r"""Embeds a tokenized batch, pools it according to
        :obj:`self.pooling_strategy` and L2-normalizes along dimension 1.

        Args:
            input_ids (torch.Tensor): Token IDs as produced by the tokenizer.
            attention_mask (torch.Tensor): The matching attention mask.
        """
        out = self.model(input_ids=input_ids, attention_mask=attention_mask)

        emb = out[0]  # First element contains all token embeddings.
        if self.pooling_strategy == PoolingStrategy.MEAN:
            emb = mean_pooling(emb, attention_mask)
        elif self.pooling_strategy == PoolingStrategy.LAST:
            emb = last_pooling(emb, attention_mask)
        elif self.pooling_strategy == PoolingStrategy.LAST_HIDDEN_STATE:
            emb = out.last_hidden_state
        else:
            assert self.pooling_strategy == PoolingStrategy.CLS
            emb = emb[:, 0, :]

        # NOTE(review): `LAST_HIDDEN_STATE` keeps the per-token dimension, so
        # for it `dim=1` normalizes across tokens rather than across features
        # (unlike the pooled strategies) - confirm this is intended.
        emb = F.normalize(emb, p=2, dim=1)
        return emb

    def _tokenize(self, text: List[str]):
        # Tokenizes one batch, padded to its own longest sequence and
        # truncated to `self.max_seq_length`.
        return self.tokenizer(
            text,
            padding=True,
            truncation=True,
            return_tensors='pt',
            max_length=self.max_seq_length,
        )

    @staticmethod
    def _resolve_batch_size(text: List[str], batch_size: Optional[int]) -> int:
        # `None` means "everything in one batch"; non-positive sizes would
        # make `range` fail or silently produce no batches at all.
        if batch_size is None:
            return len(text)
        if batch_size < 1:
            raise ValueError(f"'batch_size' must be a positive integer "
                             f"(got {batch_size})")
        return batch_size

    @staticmethod
    def _pad_and_cat(tensors: List[Tensor], value: int, left: bool) -> Tensor:
        # Each batch is padded only to its own longest sequence, so batches
        # may differ in width; pad them to a common width before concatenating
        # along the batch dimension.
        if len(tensors) == 1:
            return tensors[0]
        width = max(t.size(1) for t in tensors)
        padded = []
        for t in tensors:
            missing = width - t.size(1)
            pad = (missing, 0) if left else (0, missing)
            padded.append(F.pad(t, pad, value=value))
        return torch.cat(padded, dim=0)

    def get_input_ids(
        self,
        text: List[str],
        batch_size: Optional[int] = None,
        output_device: Optional[Union[torch.device, str]] = None,
    ) -> Tuple[Tensor, Tensor]:
        r"""Tokenizes :obj:`text` in batches.

        Args:
            text (List[str]): List of strings to tokenize.
            batch_size (int, optional): How many strings to tokenize at once.
                Defaults to all at once. (default: :obj:`None`)
            output_device (Union[torch.device, str], optional): Device to
                move the results to. (default: :obj:`None`)

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The input IDs and attention
            masks of all strings, padded to a common length (on the
            tokenizer's padding side). For empty :obj:`text`, both have zero
            rows.

        Raises:
            ValueError: If :obj:`batch_size` is smaller than 1.
        """
        is_empty = len(text) == 0
        # Tokenize a placeholder so empty input still yields a 2D result.
        text = ['dummy'] if is_empty else text
        batch_size = self._resolve_batch_size(text, batch_size)

        input_ids: List[Tensor] = []
        attention_masks: List[Tensor] = []
        for start in range(0, len(text), batch_size):
            token = self._tokenize(text[start:start + batch_size])
            input_ids.append(token.input_ids.to(self.device))
            attention_masks.append(token.attention_mask.to(self.device))

        pad_id = getattr(self.tokenizer, 'pad_token_id', None)
        pad_id = 0 if pad_id is None else pad_id
        left = getattr(self.tokenizer, 'padding_side', 'right') == 'left'

        def _out(x: List[Tensor], value: int) -> Tensor:
            out = self._pad_and_cat(x, value, left)
            out = out[:0] if is_empty else out
            return out.to(output_device)

        return _out(input_ids, pad_id), _out(attention_masks, 0)

    @property
    def device(self) -> torch.device:
        return next(iter(self.model.parameters())).device

    def _embed(self, token) -> Tensor:
        # Runs `forward` on a tokenized batch on the model's current device.
        return self(
            input_ids=token.input_ids.to(self.device),
            attention_mask=token.attention_mask.to(self.device),
        )

    def _embed_with_cpu_fallback(self, token) -> Tensor:
        # Embeds a batch; if that fails off-CPU (e.g. out of memory for huge
        # strings), retries once on CPU and always moves the model back.
        try:
            return self._embed(token)
        except RuntimeError:  # Includes `torch.cuda.OutOfMemoryError`.
            previous_device = self.device
            if previous_device.type == 'cpu':
                raise  # Already on CPU: nothing to fall back to.
            import warnings
            warnings.warn(f"Sentence Transformer failed on {previous_device}, "
                          f"trying w/ cpu...")
            self.model = self.model.to('cpu')
            try:
                return self._embed(token)
            finally:
                self.model = self.model.to(previous_device)

    @torch.no_grad()
    def encode(
        self,
        text: List[str],
        batch_size: Optional[int] = None,
        output_device: Optional[Union[torch.device, str]] = None,
        verbose=False,
    ) -> Tensor:
        r"""Main function for users. Converts strings to embeddings.

        Args:
            text (List[str]): List of strings to embed.
            batch_size (int, optional): How many strings to process.
                Defaults to processing all at once, but this may lead to
                OOM errors. (default: obj:`None`)
            output_device (Union[torch.device, str], optional):
                Device to move each batch of embeddings to.
                (default: obj:`None`)
            verbose (bool, optional): Controls the verbosity of outputs.
                (default: obj:`False`)

        Raises:
            ValueError: If :obj:`batch_size` is smaller than 1.
        """
        is_empty = len(text) == 0
        text = ['dummy'] if is_empty else text
        batch_size = self._resolve_batch_size(text, batch_size)

        embs: List[Tensor] = []
        loader = range(0, len(text), batch_size)
        if verbose:
            loader = tqdm(
                loader, desc="Encoding " + str(len(text)) +
                " strings w/ SentenceTransformer")
        for start in loader:
            token = self._tokenize(text[start:start + batch_size])
            embs.append(self._embed_with_cpu_fallback(token).to(output_device))

        # NOTE(review): with `LAST_HIDDEN_STATE`, batches of different padded
        # lengths cannot be concatenated here.
        out = torch.cat(embs, dim=0) if len(embs) > 1 else embs[0]
        out = out[:0] if is_empty else out
        return out

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(model_name={self.model_name})'
173
+
174
+
175
def mean_pooling(emb: Tensor, attention_mask: Tensor) -> Tensor:
    r"""Averages token embeddings over the positions kept by the mask.

    Args:
        emb (torch.Tensor): Token embeddings; the mask gains a trailing
            dimension and is expanded to ``emb``'s shape, so this is
            presumably :obj:`[batch_size, seq_len, dim]`.
        attention_mask (torch.Tensor): Mask marking the tokens to average.

    Returns:
        torch.Tensor: The masked mean over dimension 1. Fully masked rows
        yield zeros, since the token count is clamped away from zero.
    """
    weights = attention_mask.unsqueeze(-1).expand_as(emb).to(emb.dtype)
    summed = (emb * weights).sum(dim=1)
    counts = weights.sum(dim=1).clamp(min=1e-9)
    return summed / counts
178
+
179
+
180
def last_pooling(emb: Tensor, attention_mask: Tensor) -> Tensor:
    r"""Selects the embedding of the last non-padding token of every row.

    Args:
        emb (torch.Tensor): Token embeddings indexed as ``emb[row, position]``.
        attention_mask (torch.Tensor): Mask marking real tokens.

    Returns:
        torch.Tensor: One embedding per row.
    """
    num_rows = attention_mask.size(0)
    # When every row has a real token in the final column, the batch is
    # left-padded (always the case for decoder LLMs), so that column is the
    # answer for all rows.
    if attention_mask[:, -1].sum() == num_rows:
        return emb[:, -1]

    # Right padding: the last real token of a row sits at (count - 1).
    last_index = attention_mask.sum(dim=1) - 1
    rows = torch.arange(emb.size(0), device=emb.device)
    return emb[rows, last_index]
@@ -0,0 +1,353 @@
1
+ import os
2
+ import time
3
+ from typing import List, Optional, Tuple
4
+
5
+ import torch
6
+ import torch.multiprocessing as mp
7
+
8
# Per-process, lazily created OpenAI-compatible client used by
# `_chunk_to_triples_str_cloud`; `CLIENT_INITD` records whether it exists.
CLIENT_INITD = False

CLIENT = None
# NOTE(review): not read by the code shown here; the `GLOBAL_NIM_KEY`
# parameter of `_chunk_to_triples_str_cloud` shadows it.
GLOBAL_NIM_KEY = ""
# Instruction placed on a new line after each text chunk sent to the LLM.
SYSTEM_PROMPT = (
    "Please convert the above text into a list of knowledge triples with the "
    "form ('entity', 'relation', 'entity'). Separate each with a new line. "
    "Do not output anything else. Try to focus on key triples that form a "
    "connected graph.")
13
+
14
+
15
class TXT2KG():
    """A class to convert text data into a Knowledge Graph (KG) format.
    Uses NVIDIA NIMs + Prompt engineering by default.
    Default model `nvidia/llama-3.1-nemotron-70b-instruct`
    is on par or better than GPT4o in benchmarks.
    We need a high quality model to ensure high quality KG.
    Otherwise we have garbage in garbage out for the rest of the
    GNN+LLM RAG pipeline.

    Use the `local_LM` flag for local debugging/dev. You still need to be
    able to run inference on a 14B param LLM,
    'VAGOsolutions/SauerkrautLM-v2-14b-DPO'.
    Smaller LLMs did not work at all in testing.
    Note this 14B model requires a considerable amount of GPU memory.
    See examples/llm/txt2kg_rag.py for an example.

    Args:
        NVIDIA_NIM_MODEL : str, optional
            The name of the NVIDIA NIM model to use.
            (default: "nvidia/llama-3.1-nemotron-70b-instruct").
        NVIDIA_API_KEY : str, optional
            The API key for accessing NVIDIA's NIM models (default: "").
        ENDPOINT_URL : str, optional
            The URL hosting your model, in case you are not using
            the public NIM.
            (default: "https://integrate.api.nvidia.com/v1").
        local_LM : bool, optional
            A flag indicating whether a local Language Model (LM)
            should be used. This uses HuggingFace and will be slower
            than deploying your own private NIM endpoint. This flag
            is mainly recommended for dev/debug.
            (default: False).
        chunk_size : int, optional
            The maximum number of characters per text chunk sent to the
            LLM. The local LM also uses it as its `max_tokens` limit.
            (default: 512).
    """
    # Outer attempts (spawn + collect results) and spawn retries per attempt
    # used for cloud extraction; txt2kg is costly, so stoppage is costly.
    _NUM_ATTEMPTS = 5
    _NUM_SPAWN_RETRIES = 200
    # Per-rank result file written by `_multiproc_helper` (must match it).
    _PROC_OUT_PREFIX = "/tmp/outs_for_proc_"

    def __init__(
        self,
        NVIDIA_NIM_MODEL: Optional[
            str] = "nvidia/llama-3.1-nemotron-70b-instruct",
        NVIDIA_API_KEY: Optional[str] = "",
        ENDPOINT_URL: Optional[str] = "https://integrate.api.nvidia.com/v1",
        local_LM: bool = False,
        chunk_size: int = 512,
    ) -> None:
        self.local_LM = local_LM
        # Initialize the local LM flag and the NIM model info accordingly
        if self.local_LM:
            # The local LM is loaded lazily on first use
            self.initd_LM = False
        else:
            # If not using a local LM, store the provided NIM model info
            self.NVIDIA_API_KEY = NVIDIA_API_KEY
            self.NIM_MODEL = NVIDIA_NIM_MODEL
            self.ENDPOINT_URL = ENDPOINT_URL

        # Set the chunk size for processing text data
        self.chunk_size = chunk_size

        # Initialize counters and storage for parsing results
        self.doc_id_counter = 0
        self.relevant_triples = {}
        self.total_chars_parsed = 0
        self.time_to_parse = 0.0
        self.avg_chars_parsed_per_sec = 0.0

    def save_kg(self, path: str) -> None:
        """Saves the relevant triples in the knowledge graph (KG) to a file.

        Args:
            path (str): The file path where the KG will be saved.

        Returns:
            None
        """
        torch.save(self.relevant_triples, path)

    def _chunk_to_triples_str_local(self, txt: str) -> str:
        # Runs the local LM on one chunk and returns its raw response,
        # loading the model on first use and updating parse statistics.
        chunk_start_time = time.time()
        if not self.initd_LM:
            # `torch_geometric.nn.nlp` no longer exists; `LLM` lives in
            # `torch_geometric.llm.models`.
            from torch_geometric.llm.models.llm import LLM
            LM_name = "VAGOsolutions/SauerkrautLM-v2-14b-DPO"
            self.model = LLM(LM_name).eval()
            self.initd_LM = True
        out_str = self.model.inference(question=[txt + '\n' + SYSTEM_PROMPT],
                                       max_tokens=self.chunk_size)[0]
        # for debug
        self.total_chars_parsed += len(txt)
        self.time_to_parse += round(time.time() - chunk_start_time, 2)
        # The rounded elapsed time can still be 0 after fast chunks.
        if self.time_to_parse > 0:
            self.avg_chars_parsed_per_sec = (self.total_chars_parsed /
                                             self.time_to_parse)
        return out_str

    def add_doc_2_KG(
        self,
        txt: str,
        QA_pair: Optional[Tuple[str, str]] = None,
    ) -> None:
        """Add a document to the Knowledge Graph (KG).

        Args:
            txt (str): The text to extract triples from.
            QA_pair (Tuple[str, str]], optional):
                A QA pair to associate with the extracted triples.
                Useful for downstream evaluation. QA pairs are unique keys:
                a pair that was already added is skipped.

        Returns:
            - None
        """
        if not self.local_LM:
            # Ensure NVIDIA_API_KEY is set (neither empty nor None)
            assert self.NVIDIA_API_KEY, \
                "Please init TXT2KG w/ NVIDIA_API_KEY or set local_LM=True"
        if QA_pair:
            # QA_pairs should be unique keys, check if already exists in KG
            if QA_pair in self.relevant_triples:
                print("Warning: QA_Pair was already added to the set")
                print("Q=", QA_pair[0])
                print("A=", QA_pair[1])
                print("Previously parsed triples=",
                      self.relevant_triples[QA_pair])
                print("Skipping...")
                # Keep counting every submitted document, as before.
                self.doc_id_counter += 1
                return
            key = QA_pair
        else:
            # If no QA_pair, use the current doc_id_counter as the key
            key = self.doc_id_counter

        # Handle empty text (context-less QA pairs)
        if txt == "":
            self.relevant_triples[key] = []
        else:
            # Chunk the text into smaller pieces for processing
            chunks = _chunk_text(txt, chunk_size=self.chunk_size)

            if self.local_LM:
                # For debugging purposes...
                # process chunks sequentially on the local LM
                self.relevant_triples[key] = _llm_then_python_parse(
                    chunks, _parse_n_check_triples,
                    self._chunk_to_triples_str_local)
            else:
                # Process chunks in parallel using multiple processes
                self.relevant_triples[key] = self._chunks_to_triples_cloud(
                    chunks)
        # Increment the doc_id_counter for the next document
        self.doc_id_counter += 1

    @staticmethod
    def _split_chunks(chunks: List[str],
                      num_procs: int) -> dict[int, List[str]]:
        # Splits `chunks` into `num_procs` contiguous, order-preserving parts
        # whose sizes differ by at most one, so that no chunk is dropped.
        base, extra = divmod(len(chunks), num_procs)
        parts = {}
        start = 0
        for rank in range(num_procs):
            stop = start + base + (1 if rank < extra else 0)
            parts[rank] = chunks[start:stop]
            start = stop
        return parts

    def _spawn_workers(self, in_chunks_per_proc: dict,
                       num_procs: int) -> bool:
        # Runs one `_multiproc_helper` process per rank, retrying on failure.
        # Returns whether a spawn round completed successfully.
        for _ in range(self._NUM_SPAWN_RETRIES):
            try:
                mp.spawn(
                    _multiproc_helper,
                    args=(in_chunks_per_proc, _parse_n_check_triples,
                          _chunk_to_triples_str_cloud, self.NVIDIA_API_KEY,
                          self.NIM_MODEL, self.ENDPOINT_URL),
                    nprocs=num_procs,
                )
                return True
            except Exception:
                # Keep retrying; `Exception` (not a bare `except`) still lets
                # Ctrl+C through.
                continue
        return False

    def _chunks_to_triples_cloud(
            self, chunks: List[str]) -> List[Tuple[str, str, str]]:
        # Extracts triples from `chunks` in parallel worker processes and
        # collects the per-rank result files in rank order.
        if not chunks:
            return []
        num_procs = min(len(chunks), max(1, _get_num_procs()))
        in_chunks_per_proc = self._split_chunks(chunks, num_procs)

        triples = []
        for _ in range(self._NUM_ATTEMPTS):
            if not self._spawn_workers(in_chunks_per_proc, num_procs):
                # Never load result files after a failed spawn round: they
                # would be missing or stale.
                continue
            triples = []
            try:
                for rank in range(num_procs):
                    path = self._PROC_OUT_PREFIX + str(rank)
                    triples += torch.load(path)
                    os.remove(path)
                return triples
            except Exception:
                continue

        import warnings
        warnings.warn(f"TXT2KG could not process all {len(chunks)} chunk(s) "
                      f"after {self._NUM_ATTEMPTS} attempts; keeping "
                      f"{len(triples)} triple(s) that were collected.")
        return triples
194
+
195
+
196
# Model-name substrings identifying reasoning models. When the requested NIM
# model name contains one of them, `_chunk_to_triples_str_cloud` prepends a
# "detailed thinking on" system message.
known_reasoners = [
    "llama-3.1-nemotron-ultra-253b-v1",
    "kimi-k2-instruct",
    "nemotron-super-49b-v1_5",
    "gpt-oss",
]
+ ]
202
+
203
+
204
def _chunk_to_triples_str_cloud(
        txt: str, GLOBAL_NIM_KEY='',
        NIM_MODEL="nvidia/llama-3.1-nemotron-ultra-253b-v1",
        ENDPOINT_URL="https://integrate.api.nvidia.com/v1",
        post_text=SYSTEM_PROMPT) -> str:
    """Asks a hosted, OpenAI-compatible chat endpoint for KG triples.

    Args:
        txt (str): The text chunk to extract triples from.
        GLOBAL_NIM_KEY (str, optional): API key for the endpoint.
        NIM_MODEL (str, optional): Name of the model to query.
        ENDPOINT_URL (str, optional): Base URL of the endpoint.
        post_text (str, optional): Instruction appended to :obj:`txt` on a
            new line; nothing is appended when it is empty.

    Returns:
        str: The concatenated streamed response of the model.

    Raises:
        ImportError: If the :obj:`openai` package is not installed.
    """
    global CLIENT_INITD, CLIENT
    if not CLIENT_INITD:
        # We use NIMs since most PyG users may not be able to run a 70B+ model
        try:
            from openai import OpenAI
        except ImportError as err:
            # Raise rather than `quit()`: library code must not terminate the
            # interpreter (and `quit` does not exist under `python -S`).
            raise ImportError(
                "Failed to import `openai` package, please install it and "
                "rerun the script") from err
        # NOTE: The client is created once per process from the first key and
        # endpoint seen; later calls with other values reuse it.
        CLIENT = OpenAI(base_url=ENDPOINT_URL, api_key=GLOBAL_NIM_KEY)
        CLIENT_INITD = True

    txt_input = txt
    if post_text != "":
        txt_input += '\n' + post_text

    messages = []
    # Models listed in `known_reasoners` get a "detailed thinking on" system
    # message in front of the user prompt.
    if any(name in NIM_MODEL for name in known_reasoners):
        messages.append({"role": "system", "content": "detailed thinking on"})
    messages.append({"role": "user", "content": txt_input})
    completion = CLIENT.chat.completions.create(model=NIM_MODEL,
                                                messages=messages,
                                                temperature=0, top_p=1,
                                                max_tokens=1024, stream=True)
    pieces = []
    for chunk in completion:
        # Some stream events (e.g. a trailing usage report) carry no choices.
        if not chunk.choices:
            continue
        content = chunk.choices[0].delta.content
        if content is not None:
            pieces.append(content)
    return "".join(pieces)
238
+
239
+
240
+ def _parse_n_check_triples(triples_str: str) -> List[Tuple[str, str, str]]:
241
+ # use pythonic checks for triples
242
+ processed = []
243
+ split_by_newline = triples_str.split("\n")
244
+ # sometimes LLM fails to obey the prompt
245
+ if len(split_by_newline) > 1:
246
+ split_triples = split_by_newline
247
+ llm_obeyed = True
248
+ else:
249
+ # handles form "(e, r, e) (e, r, e) ... (e, r, e)""
250
+ split_triples = triples_str[1:-1].split(") (")
251
+ llm_obeyed = False
252
+ for triple_str in split_triples:
253
+ try:
254
+ if llm_obeyed:
255
+ # remove parenthesis and single quotes for parsing
256
+ triple_str = triple_str.replace("(", "").replace(")",
257
+ "").replace(
258
+ "'", "")
259
+ split_trip = triple_str.split(',')
260
+ # remove blank space at beginning or end
261
+ split_trip = [(i[1:] if i[0] == " " else i) for i in split_trip]
262
+ split_trip = [(i[:-1].lower() if i[-1] == " " else i)
263
+ for i in split_trip]
264
+ potential_trip = tuple(split_trip)
265
+ except: # noqa
266
+ continue
267
+ if 'tuple' in str(type(potential_trip)) and len(
268
+ potential_trip
269
+ ) == 3 and "note:" not in potential_trip[0].lower():
270
+ # additional check for empty node/edge attrs
271
+ if potential_trip[0] != '' and potential_trip[
272
+ 1] != '' and potential_trip[2] != '':
273
+ processed.append(potential_trip)
274
+ return processed
275
+
276
+
277
+ def _llm_then_python_parse(chunks, py_fn, llm_fn, **kwargs):
278
+ relevant_triples = []
279
+ for chunk in chunks:
280
+ relevant_triples += py_fn(llm_fn(chunk, **kwargs))
281
+ return relevant_triples
282
+
283
+
284
def _multiproc_helper(rank, in_chunks_per_proc, py_fn, llm_fn, NIM_KEY,
                      NIM_MODEL, ENDPOINT_URL):
    """Worker entry point for :func:`torch.multiprocessing.spawn`.

    Extracts triples from the chunks assigned to ``rank`` and saves them to
    ``/tmp/outs_for_proc_<rank>``, the file the parent process reads back.
    """
    llm_kwargs = dict(
        GLOBAL_NIM_KEY=NIM_KEY,
        NIM_MODEL=NIM_MODEL,
        ENDPOINT_URL=ENDPOINT_URL,
    )
    my_chunks = in_chunks_per_proc[rank]
    triples = _llm_then_python_parse(my_chunks, py_fn, llm_fn, **llm_kwargs)
    out_path = "/tmp/outs_for_proc_" + str(rank)
    torch.save(triples, out_path)
290
+
291
+
292
+ def _get_num_procs():
293
+ if hasattr(os, "sched_getaffinity"):
294
+ try:
295
+ num_proc = len(os.sched_getaffinity(0)) / (2)
296
+ except Exception:
297
+ pass
298
+ if num_proc is None:
299
+ num_proc = os.cpu_count() / (2)
300
+ return int(num_proc)
301
+
302
+
303
+ def _chunk_text(text: str, chunk_size: int = 512) -> list[str]:
304
+ """Function to chunk text into sentence-based segments.
305
+ Co-authored with Claude AI.
306
+ """
307
+ # If the input text is empty or None, return an empty list
308
+ if not text:
309
+ return []
310
+
311
+ # List of punctuation marks that typically end sentences
312
+ sentence_endings = '.!?'
313
+
314
+ # List to store the resulting chunks
315
+ chunks = []
316
+
317
+ # Continue processing the entire text
318
+ while text:
319
+ # If the remaining text is shorter than chunk_size, add it and break
320
+ if len(text) <= chunk_size:
321
+ chunks.append(text.strip())
322
+ break
323
+
324
+ # Start with the maximum possible chunk
325
+ chunk = text[:chunk_size]
326
+
327
+ # Try to find the last sentence ending within the chunk
328
+ best_split = chunk_size
329
+ for ending in sentence_endings:
330
+ # Find the last occurrence of the ending punctuation
331
+ last_ending = chunk.rfind(ending)
332
+ if last_ending != -1:
333
+ # Ensure we include the punctuation and any following space
334
+ best_split = min(
335
+ best_split, last_ending + 1 +
336
+ (1 if last_ending + 1 < len(chunk)
337
+ and chunk[last_ending + 1].isspace() else 0))
338
+
339
+ # Adjust to ensure we don't break words
340
+ # If the next character is a letter, find the last space
341
+ if best_split < len(text) and text[best_split].isalpha():
342
+ # Find the last space before the current split point
343
+ space_split = text[:best_split].rfind(' ')
344
+ if space_split != -1:
345
+ best_split = space_split
346
+
347
+ # Append the chunk, ensuring it's stripped
348
+ chunks.append(text[:best_split].strip())
349
+
350
+ # Remove the processed part from the text
351
+ text = text[best_split:].lstrip()
352
+
353
+ return chunks
@@ -0,0 +1,38 @@
1
+ from typing import Optional, Union
2
+
3
+ import torch
4
+ from torch import Tensor
5
+
6
+
7
class VisionTransformer(torch.nn.Module):
    r"""A wrapper around a Swin Vision-Transformer from HuggingFace whose
    forward pass runs without gradient tracking.

    Args:
        model_name (str): The HuggingFace name of a Swin checkpoint,
            *e.g.*, :obj:`"microsoft/swin-tiny-patch4-window7-224"`.
        pretrained (bool, optional): If set to :obj:`True`, loads the
            pretrained weights of :obj:`model_name`. Otherwise, only its
            configuration is loaded and the weights are randomly initialized.
            (default: :obj:`True`)
    """
    def __init__(
        self,
        model_name: str,
        pretrained: bool = True,
    ) -> None:
        super().__init__()
        self.model_name = model_name

        from transformers import SwinConfig, SwinModel

        self.config = SwinConfig.from_pretrained(model_name)
        if pretrained:
            # `forward` runs under `torch.no_grad()`, so the encoder cannot
            # learn through it: randomly initialized weights would only yield
            # meaningless features.
            self.model = SwinModel.from_pretrained(model_name,
                                                   config=self.config)
        else:
            self.model = SwinModel(self.config)

    @torch.no_grad()
    def forward(
        self,
        images: Tensor,
        output_device: Optional[Union[torch.device, str]] = None,
    ) -> Tensor:
        r"""Encodes a batch of images.

        Args:
            images (torch.Tensor): Pixel values passed straight to the Swin
                model (presumably :obj:`[batch_size, channels, height,
                width]` - TODO confirm against callers).
            output_device (Union[torch.device, str], optional): Device to
                move the output to. (default: :obj:`None`)

        Returns:
            torch.Tensor: The model's :obj:`last_hidden_state`.
        """
        return self.model(images).last_hidden_state.to(output_device)

    @property
    def device(self) -> torch.device:
        return next(iter(self.model.parameters())).device

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(model_name={self.model_name})'