pyg-nightly 2.7.0.dev20241009__py3-none-any.whl → 2.8.0.dev20251228__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (229) hide show
  1. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/METADATA +77 -53
  2. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/RECORD +227 -190
  3. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251228.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +14 -2
  6. torch_geometric/_compile.py +9 -3
  7. torch_geometric/_onnx.py +214 -0
  8. torch_geometric/config_mixin.py +5 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/batch.py +2 -2
  13. torch_geometric/data/collate.py +1 -3
  14. torch_geometric/data/data.py +109 -5
  15. torch_geometric/data/database.py +4 -0
  16. torch_geometric/data/dataset.py +14 -11
  17. torch_geometric/data/extract.py +1 -1
  18. torch_geometric/data/feature_store.py +17 -22
  19. torch_geometric/data/graph_store.py +3 -2
  20. torch_geometric/data/hetero_data.py +139 -7
  21. torch_geometric/data/hypergraph_data.py +2 -2
  22. torch_geometric/data/in_memory_dataset.py +2 -2
  23. torch_geometric/data/lightning/datamodule.py +42 -28
  24. torch_geometric/data/storage.py +9 -1
  25. torch_geometric/datasets/__init__.py +18 -1
  26. torch_geometric/datasets/actor.py +7 -9
  27. torch_geometric/datasets/airfrans.py +15 -17
  28. torch_geometric/datasets/airports.py +8 -10
  29. torch_geometric/datasets/amazon.py +8 -11
  30. torch_geometric/datasets/amazon_book.py +8 -9
  31. torch_geometric/datasets/amazon_products.py +7 -9
  32. torch_geometric/datasets/aminer.py +8 -9
  33. torch_geometric/datasets/aqsol.py +10 -13
  34. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  35. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  36. torch_geometric/datasets/ba_shapes.py +5 -6
  37. torch_geometric/datasets/city.py +157 -0
  38. torch_geometric/datasets/dbp15k.py +1 -1
  39. torch_geometric/datasets/git_mol_dataset.py +263 -0
  40. torch_geometric/datasets/hgb_dataset.py +2 -2
  41. torch_geometric/datasets/hm.py +1 -1
  42. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  43. torch_geometric/datasets/md17.py +3 -3
  44. torch_geometric/datasets/medshapenet.py +145 -0
  45. torch_geometric/datasets/modelnet.py +1 -1
  46. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  47. torch_geometric/datasets/molecule_net.py +3 -2
  48. torch_geometric/datasets/ppi.py +2 -1
  49. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  50. torch_geometric/datasets/qm7.py +1 -1
  51. torch_geometric/datasets/qm9.py +1 -1
  52. torch_geometric/datasets/snap_dataset.py +8 -4
  53. torch_geometric/datasets/tag_dataset.py +462 -0
  54. torch_geometric/datasets/teeth3ds.py +269 -0
  55. torch_geometric/datasets/web_qsp_dataset.py +310 -209
  56. torch_geometric/datasets/wikics.py +2 -1
  57. torch_geometric/deprecation.py +1 -1
  58. torch_geometric/distributed/__init__.py +13 -0
  59. torch_geometric/distributed/dist_loader.py +2 -2
  60. torch_geometric/distributed/partition.py +2 -2
  61. torch_geometric/distributed/rpc.py +3 -3
  62. torch_geometric/edge_index.py +18 -14
  63. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  64. torch_geometric/explain/algorithm/base.py +2 -2
  65. torch_geometric/explain/algorithm/captum.py +1 -1
  66. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  67. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  68. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  69. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  70. torch_geometric/explain/explainer.py +2 -2
  71. torch_geometric/explain/explanation.py +87 -3
  72. torch_geometric/explain/metric/faithfulness.py +1 -1
  73. torch_geometric/graphgym/config.py +3 -2
  74. torch_geometric/graphgym/imports.py +15 -4
  75. torch_geometric/graphgym/logger.py +1 -1
  76. torch_geometric/graphgym/loss.py +1 -1
  77. torch_geometric/graphgym/models/encoder.py +2 -2
  78. torch_geometric/graphgym/models/layer.py +1 -1
  79. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  80. torch_geometric/hash_tensor.py +798 -0
  81. torch_geometric/index.py +14 -5
  82. torch_geometric/inspector.py +4 -0
  83. torch_geometric/io/fs.py +5 -4
  84. torch_geometric/llm/__init__.py +9 -0
  85. torch_geometric/llm/large_graph_indexer.py +741 -0
  86. torch_geometric/llm/models/__init__.py +23 -0
  87. torch_geometric/{nn → llm}/models/g_retriever.py +77 -45
  88. torch_geometric/llm/models/git_mol.py +336 -0
  89. torch_geometric/llm/models/glem.py +397 -0
  90. torch_geometric/{nn/nlp → llm/models}/llm.py +180 -32
  91. torch_geometric/llm/models/llm_judge.py +158 -0
  92. torch_geometric/llm/models/molecule_gpt.py +222 -0
  93. torch_geometric/llm/models/protein_mpnn.py +333 -0
  94. torch_geometric/llm/models/sentence_transformer.py +188 -0
  95. torch_geometric/llm/models/txt2kg.py +353 -0
  96. torch_geometric/llm/models/vision_transformer.py +38 -0
  97. torch_geometric/llm/rag_loader.py +154 -0
  98. torch_geometric/llm/utils/__init__.py +10 -0
  99. torch_geometric/llm/utils/backend_utils.py +443 -0
  100. torch_geometric/llm/utils/feature_store.py +169 -0
  101. torch_geometric/llm/utils/graph_store.py +199 -0
  102. torch_geometric/llm/utils/vectorrag.py +125 -0
  103. torch_geometric/loader/cluster.py +4 -4
  104. torch_geometric/loader/ibmb_loader.py +4 -4
  105. torch_geometric/loader/link_loader.py +1 -1
  106. torch_geometric/loader/link_neighbor_loader.py +2 -1
  107. torch_geometric/loader/mixin.py +6 -5
  108. torch_geometric/loader/neighbor_loader.py +1 -1
  109. torch_geometric/loader/neighbor_sampler.py +2 -2
  110. torch_geometric/loader/prefetch.py +3 -2
  111. torch_geometric/loader/temporal_dataloader.py +2 -2
  112. torch_geometric/loader/utils.py +10 -10
  113. torch_geometric/metrics/__init__.py +14 -0
  114. torch_geometric/metrics/link_pred.py +745 -92
  115. torch_geometric/nn/__init__.py +1 -0
  116. torch_geometric/nn/aggr/base.py +1 -1
  117. torch_geometric/nn/aggr/equilibrium.py +1 -1
  118. torch_geometric/nn/aggr/fused.py +1 -1
  119. torch_geometric/nn/aggr/patch_transformer.py +8 -2
  120. torch_geometric/nn/aggr/set_transformer.py +1 -1
  121. torch_geometric/nn/aggr/utils.py +9 -4
  122. torch_geometric/nn/attention/__init__.py +9 -1
  123. torch_geometric/nn/attention/polynormer.py +107 -0
  124. torch_geometric/nn/attention/qformer.py +71 -0
  125. torch_geometric/nn/attention/sgformer.py +99 -0
  126. torch_geometric/nn/conv/__init__.py +2 -0
  127. torch_geometric/nn/conv/appnp.py +1 -1
  128. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  129. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  130. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  131. torch_geometric/nn/conv/dna_conv.py +1 -1
  132. torch_geometric/nn/conv/eg_conv.py +7 -7
  133. torch_geometric/nn/conv/gen_conv.py +1 -1
  134. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  135. torch_geometric/nn/conv/hetero_conv.py +2 -1
  136. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  137. torch_geometric/nn/conv/message_passing.py +5 -4
  138. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  139. torch_geometric/nn/conv/sg_conv.py +1 -1
  140. torch_geometric/nn/conv/spline_conv.py +2 -1
  141. torch_geometric/nn/conv/ssg_conv.py +1 -1
  142. torch_geometric/nn/conv/transformer_conv.py +5 -3
  143. torch_geometric/nn/data_parallel.py +5 -4
  144. torch_geometric/nn/dense/linear.py +0 -20
  145. torch_geometric/nn/encoding.py +17 -3
  146. torch_geometric/nn/fx.py +14 -12
  147. torch_geometric/nn/model_hub.py +2 -15
  148. torch_geometric/nn/models/__init__.py +11 -2
  149. torch_geometric/nn/models/attentive_fp.py +1 -1
  150. torch_geometric/nn/models/attract_repel.py +148 -0
  151. torch_geometric/nn/models/basic_gnn.py +2 -1
  152. torch_geometric/nn/models/captum.py +1 -1
  153. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  154. torch_geometric/nn/models/dimenet.py +2 -2
  155. torch_geometric/nn/models/dimenet_utils.py +4 -2
  156. torch_geometric/nn/models/gpse.py +1083 -0
  157. torch_geometric/nn/models/graph_unet.py +13 -4
  158. torch_geometric/nn/models/lpformer.py +783 -0
  159. torch_geometric/nn/models/metapath2vec.py +1 -1
  160. torch_geometric/nn/models/mlp.py +4 -2
  161. torch_geometric/nn/models/node2vec.py +1 -1
  162. torch_geometric/nn/models/polynormer.py +206 -0
  163. torch_geometric/nn/models/rev_gnn.py +3 -3
  164. torch_geometric/nn/models/sgformer.py +219 -0
  165. torch_geometric/nn/models/signed_gcn.py +1 -1
  166. torch_geometric/nn/models/visnet.py +2 -2
  167. torch_geometric/nn/norm/batch_norm.py +17 -7
  168. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  169. torch_geometric/nn/norm/graph_norm.py +9 -4
  170. torch_geometric/nn/norm/instance_norm.py +5 -1
  171. torch_geometric/nn/norm/layer_norm.py +15 -7
  172. torch_geometric/nn/norm/msg_norm.py +8 -2
  173. torch_geometric/nn/pool/__init__.py +8 -4
  174. torch_geometric/nn/pool/cluster_pool.py +3 -4
  175. torch_geometric/nn/pool/connect/base.py +1 -3
  176. torch_geometric/nn/pool/knn.py +13 -10
  177. torch_geometric/nn/pool/select/base.py +1 -4
  178. torch_geometric/nn/to_hetero_module.py +4 -3
  179. torch_geometric/nn/to_hetero_transformer.py +3 -3
  180. torch_geometric/nn/to_hetero_with_bases_transformer.py +4 -4
  181. torch_geometric/profile/__init__.py +2 -0
  182. torch_geometric/profile/nvtx.py +66 -0
  183. torch_geometric/profile/utils.py +20 -5
  184. torch_geometric/sampler/__init__.py +2 -1
  185. torch_geometric/sampler/base.py +336 -7
  186. torch_geometric/sampler/hgt_sampler.py +11 -1
  187. torch_geometric/sampler/neighbor_sampler.py +296 -23
  188. torch_geometric/sampler/utils.py +93 -5
  189. torch_geometric/testing/__init__.py +4 -0
  190. torch_geometric/testing/decorators.py +35 -5
  191. torch_geometric/testing/distributed.py +1 -1
  192. torch_geometric/transforms/__init__.py +2 -0
  193. torch_geometric/transforms/add_gpse.py +49 -0
  194. torch_geometric/transforms/add_metapaths.py +8 -6
  195. torch_geometric/transforms/add_positional_encoding.py +2 -2
  196. torch_geometric/transforms/base_transform.py +2 -1
  197. torch_geometric/transforms/delaunay.py +65 -15
  198. torch_geometric/transforms/face_to_edge.py +32 -3
  199. torch_geometric/transforms/gdc.py +7 -8
  200. torch_geometric/transforms/largest_connected_components.py +1 -1
  201. torch_geometric/transforms/mask.py +5 -1
  202. torch_geometric/transforms/normalize_features.py +3 -3
  203. torch_geometric/transforms/random_link_split.py +1 -1
  204. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  205. torch_geometric/transforms/rooted_subgraph.py +1 -1
  206. torch_geometric/typing.py +70 -17
  207. torch_geometric/utils/__init__.py +4 -1
  208. torch_geometric/utils/_lexsort.py +0 -9
  209. torch_geometric/utils/_negative_sampling.py +27 -12
  210. torch_geometric/utils/_scatter.py +132 -195
  211. torch_geometric/utils/_sort_edge_index.py +0 -2
  212. torch_geometric/utils/_spmm.py +16 -14
  213. torch_geometric/utils/_subgraph.py +4 -0
  214. torch_geometric/utils/_to_dense_batch.py +2 -2
  215. torch_geometric/utils/_trim_to_layer.py +2 -2
  216. torch_geometric/utils/convert.py +17 -10
  217. torch_geometric/utils/cross_entropy.py +34 -13
  218. torch_geometric/utils/embedding.py +91 -2
  219. torch_geometric/utils/geodesic.py +4 -3
  220. torch_geometric/utils/influence.py +279 -0
  221. torch_geometric/utils/map.py +13 -9
  222. torch_geometric/utils/nested.py +1 -1
  223. torch_geometric/utils/smiles.py +3 -3
  224. torch_geometric/utils/sparse.py +7 -14
  225. torch_geometric/visualization/__init__.py +2 -1
  226. torch_geometric/visualization/graph.py +250 -5
  227. torch_geometric/warnings.py +11 -2
  228. torch_geometric/nn/nlp/__init__.py +0 -7
  229. torch_geometric/nn/nlp/sentence_transformer.py +0 -101
@@ -0,0 +1,158 @@
1
+ from math import isnan
2
+ from typing import Optional
3
+
4
+ from torch_geometric.llm.models.txt2kg import \
5
+ _chunk_to_triples_str_cloud as call_NIM
6
+
7
+ # Credit for original "Marlin Accuracy" system goes to:
8
+ # Gilberto Titericz (NVIDIA)
9
+ # This work is an adaptation of his for PyG
10
# Prompt used for the first judging pass. The placeholders ``{question}``,
# ``{model_pred}`` and ``{correct_answer}`` are filled in via ``str.format``.
# The judge is asked to answer with a bare rating of 4, 2 or 0.
#
# NOTE: adjacent fragments are concatenated verbatim, so every fragment that
# ends mid-sentence must carry its own trailing (or leading) space. Previously
# the text read "Reference Answerin all terms" and "topics,numbers".
SYSTEM_PROMPT_1 = (
    "Instruction: You are a world class state of the art "
    "assistant for rating "
    "a User Answer given a Question. The Question is completely"
    " answered by the Reference Answer.\n"
    "Say 4, if User Answer is full contained and equivalent to"
    " Reference Answer "
    "in all terms, topics, numbers, metrics, dates and units.\n"
    "Say 2, if User Answer is partially contained and almost "
    "equivalent to Reference Answer "
    "in all terms, topics, numbers, metrics, dates and units.\n"
    "Say 0, if User Answer is not contained in Reference Answer"
    " or not accurate in all terms, topics, "
    "numbers, metrics, dates and units or the User Answer do not"
    " answer the question.\n"
    "Do not explain or justify your rating. Your rating must be "
    "only 4, 2 or 0 according to the instructions above.\n"
    "### Question: \"{question}\"\n"
    "### User Answer: \"{model_pred}\"\n"
    "### Reference Answer: \"{correct_answer}\"\n"
    "The rating is:\n")

# Prompt used for the second judging pass. It takes the same three
# placeholders as ``SYSTEM_PROMPT_1``.
#
# NOTE(review): here ``{model_pred}`` is labelled "Reference Answer" and
# ``{correct_answer}`` is labelled "User Answer", i.e. the two roles are
# swapped relative to ``SYSTEM_PROMPT_1``. This looks intentional (the second
# score is described elsewhere in this file as obtained "by permuting agent
# answer and ground truth") -- confirm before "fixing" it.
SYSTEM_PROMPT_2 = (
    "I will rate the User Answer in comparison to the Reference "
    "Answer for a given Question.\n"
    "A rating of 4 indicates that the User Answer is entirely "
    "consistent with the Reference Answer, covering all aspects,"
    " topics, numbers, metrics, dates, and units.\n"
    "A rating of 2 signifies that the User Answer is mostly "
    "aligned with the Reference Answer, with minor discrepancies"
    " in some areas.\n"
    "A rating of 0 means that the User Answer is either "
    "inaccurate, incomplete, or unrelated to the Reference "
    "Answer, or it fails to address the Question.\n"
    "I will provide the rating without any explanation or "
    "justification, adhering to the following scale: "
    "0 (no match), 2 (partial match), 4 (exact match).\n"
    "Do not explain or justify my rating. My rating must"
    " be only 4, 2 or 0 only.\n\n"
    "Question: \"{question}\"\n\n"
    "Reference Answer: \"{model_pred}\"\n\n"
    "User Answer: \"{correct_answer}\"\n\n"
    "Rating: ")
49
+
50
+
51
+ # TODO: add support for Local LM
52
+ # TODO: add multiproc support like txt2kg
53
class LLMJudge():
    """Uses NIMs to score a triple of (question, model_pred, correct_answer).

    The judge is queried twice, once with each system prompt, and the two
    normalized ratings are combined into a single score.
    This whole class is an adaptation of Gilberto's work for PyG.

    Args:
        NVIDIA_NIM_MODEL : (str, optional)
            The name of the NVIDIA NIM model to use.
            (default: "nvidia/llama-3.1-nemotron-70b-instruct").
        NVIDIA_API_KEY : (str, optional)
            The API key for accessing NVIDIA's NIM models.
            (default: "").
        ENDPOINT_URL : (str, optional)
            The URL hosting your model, in case you are not using
            the public NIM.
            (default: "https://integrate.api.nvidia.com/v1").
    """

    # Retry budgets for the two judging passes.
    # NOTE(review): the asymmetry (200 vs. 20) is preserved from the original
    # implementation; it may be unintentional.
    _MAX_RETRIES_PROMPT_1 = 200
    _MAX_RETRIES_PROMPT_2 = 20

    def __init__(
        self,
        NVIDIA_NIM_MODEL: Optional[
            str] = "nvidia/llama-3.1-nemotron-70b-instruct",
        NVIDIA_API_KEY: Optional[str] = "",
        ENDPOINT_URL: Optional[str] = "https://integrate.api.nvidia.com/v1",
    ) -> None:
        self.NVIDIA_API_KEY = NVIDIA_API_KEY
        self.NIM_MODEL = NVIDIA_NIM_MODEL
        self.ENDPOINT_URL = ENDPOINT_URL

    def _process_score(self, response: str) -> float:
        """Extract a normalized rating from the raw judge response.

        Uses 3 and 1 even though prompt says only 0, 2, 4.
        This is because LLMs don't always follow instructions.
        Credit to Gilberto.

        Args:
            response (str): The raw text returned by the judge.

        Returns:
            (float) The highest digit in ``4..0`` found anywhere in
            :obj:`response`, divided by 4, or ``nan`` if none is found.
        """
        # Search from the highest rating down so that the best digit wins
        # when the response contains several of them.
        for i in (4, 3, 2, 1, 0):
            if str(i) in response:
                return i / 4
        return float("nan")

    def _average_scores(self, score0: float, score1: float) -> float:
        """Combine the scores of the two judging passes.

        Sometimes the LLM fails to respond or has no score in the response.
        In those cases the failed (``nan``) score is discarded.
        Credit to Gilberto.

        Args:
            score0 (float): judge accuracy score.
            score1 (float): judge accuracy score by permuting agent answer and
                ground truth.

        Returns:
            (float) The average of score0 and score1 if both contain scores,
            the single valid score if only one does, and ``nan`` if neither
            does.
        """
        # ``max(nan, x)`` returns ``nan`` because every comparison against
        # NaN is False, so the failed score must be dropped explicitly rather
        # than relying on ``max`` to pick the valid one.
        if isnan(score0):
            return score1
        if isnan(score1):
            return score0
        if score0 >= 0 and score1 >= 0:
            return (score0 + score1) / 2
        return max(score0, score1)

    def _score_with_retries(self, prompt: str, max_retries: int) -> float:
        """Query the judge until it yields a parsable score.

        Args:
            prompt (str): The fully formatted prompt to send.
            max_retries (int): Maximum number of attempts.

        Returns:
            (float) The first non-``nan`` score obtained, or ``nan`` if every
            attempt failed.

        Raises:
            ImportError: Propagated immediately, since retrying cannot fix a
                missing dependency.
        """
        score = float("nan")
        for _ in range(max_retries):
            try:
                score = self._process_score(
                    call_NIM(prompt, self.NVIDIA_API_KEY, self.NIM_MODEL,
                             self.ENDPOINT_URL, post_text=""))
            except ImportError:
                raise
            except Exception:
                # Best-effort: any failure of the call just triggers a
                # retry. ``Exception`` (rather than a bare ``except``) keeps
                # ``KeyboardInterrupt``/``SystemExit`` able to stop the loop.
                continue
            if not isnan(score):
                break
        return score

    def score(
        self,
        question: str,
        model_pred: str,
        correct_answer: str,
    ) -> float:
        """Rate a model prediction against the correct answer.

        Args:
            question (str): The original question asked to the model.
            model_pred (str): The prediction made by the model.
            correct_answer (str): The actual correct answer to the question.

        Returns:
            score (float): score of 0-1, may be nan due to LLM judge failure.
                Evals should skip nan's when aggregating score.
        """
        prompt1 = SYSTEM_PROMPT_1.format(question=question,
                                         model_pred=model_pred,
                                         correct_answer=correct_answer)
        prompt2 = SYSTEM_PROMPT_2.format(question=question,
                                         model_pred=model_pred,
                                         correct_answer=correct_answer)
        score1 = self._score_with_retries(prompt1,
                                          self._MAX_RETRIES_PROMPT_1)
        score2 = self._score_with_retries(prompt2,
                                          self._MAX_RETRIES_PROMPT_2)
        return self._average_scores(score1, score2)
@@ -0,0 +1,222 @@
1
+ from typing import List, Optional
2
+
3
+ import torch
4
+ from torch import Tensor
5
+
6
+ from torch_geometric.llm.models.llm import BOS, LLM, MAX_NEW_TOKENS
7
+ from torch_geometric.nn.attention import QFormer
8
+ from torch_geometric.utils import to_dense_batch
9
+
10
+
11
def pad_or_truncate(embeddings: Tensor, max_seq_len: int,
                    padding_value: int = 0) -> Tensor:
    """Force the second dimension of a 3-D tensor to ``max_seq_len``.

    Args:
        embeddings (Tensor): A three-dimensional tensor whose second
            dimension is the sequence length.
        max_seq_len (int): The desired sequence length.
        padding_value (int, optional): Fill value for appended positions.
            (default: :obj:`0`)

    Returns:
        Tensor: :obj:`embeddings` itself when the length already matches,
        a slice of its first ``max_seq_len`` positions when it is too long,
        or a new tensor padded at the end when it is too short.
    """
    num_rows, seq_len, width = embeddings.size()
    missing = max_seq_len - seq_len

    if missing == 0:
        return embeddings
    if missing < 0:
        return embeddings[:, :max_seq_len, :]

    # `new_full` inherits dtype and device from `embeddings`.
    filler = embeddings.new_full((num_rows, missing, width), padding_value)
    return torch.cat([embeddings, filler], dim=1)
24
+
25
+
26
class MoleculeGPT(torch.nn.Module):
    r"""The MoleculeGPT model from the `"MoleculeGPT: Instruction
    Following Large Language Models for Molecular Property Prediction"
    <https://ai4d3.github.io/papers/34.pdf>`_ paper.

    A graph encoder and a SMILES encoder each feed a :class:`QFormer`; the
    two fixed-length outputs are flattened, concatenated and projected to
    the embedding width of the LLM, which consumes them together with the
    textual instructions.

    Args:
        llm (LLM): The LLM to use.
        graph_encoder (torch.nn.Module): Encode 2D molecule graph.
        smiles_encoder (torch.nn.Module): Encode 1D SMILES.
        mlp_out_channels (int, optional): The size of each embedding
            after qformer encoding. (default: :obj:`32`)
        max_tokens (int, optional): Max output tokens of 1D/2D encoder.
            (default: :obj:`20`)

    .. warning::
        This module has been tested with the following HuggingFace models

        * :obj:`llm_to_use="lmsys/vicuna-7b-v1.5"`

        and may not work with other models. See other models at `HuggingFace
        Models <https://huggingface.co/models>`_ and let us know if you
        encounter any issues.

    .. note::
        For an example of using :class:`MoleculeGPT`, see
        `examples/llm/molecule_gpt.py <https://github.com/pyg-team/
        pytorch_geometric/blob/master/examples/llm/molecule_gpt.py>`_.
    """
    def __init__(
        self,
        llm: LLM,
        graph_encoder: torch.nn.Module,
        smiles_encoder: torch.nn.Module,
        mlp_out_channels: int = 32,
        max_tokens: Optional[int] = 20,
    ) -> None:
        super().__init__()
        self.llm = llm
        # Every sub-module is placed on the device of the LLM.
        self.graph_encoder = graph_encoder.to(self.llm.device)
        self.smiles_encoder = smiles_encoder.to(self.llm.device)

        # The two Q-Formers differ only in their input width.
        qformer_kwargs = dict(
            hidden_dim=mlp_out_channels,
            output_dim=mlp_out_channels,
            num_heads=4,
            num_layers=2,
        )
        graph_width = self.graph_encoder.nn[-1].out_features
        self.graph_qformer = QFormer(
            input_dim=graph_width,
            **qformer_kwargs,
        ).to(self.llm.device)

        smiles_width = self.smiles_encoder.model.pooler.dense.out_features
        self.smiles_qformer = QFormer(
            input_dim=smiles_width,
            **qformer_kwargs,
        ).to(self.llm.device)

        self.max_tokens = max_tokens

        self.word_embedding = self.llm.word_embedding
        self.llm_generator = self.llm.llm

        # Projector from the concatenated (graph + SMILES) flattened tokens
        # to the embedding width of the LLM.
        flat_dim = 2 * mlp_out_channels * max_tokens
        llm_dim = self.llm.llm.model.embed_tokens.embedding_dim
        projector_layers = [
            torch.nn.Linear(flat_dim, flat_dim),
            torch.nn.Sigmoid(),
            torch.nn.Linear(flat_dim, llm_dim),
        ]
        self.projector = torch.nn.Sequential(*projector_layers).to(
            self.llm.device)

    def encode(
        self,
        x: Tensor,
        edge_index: Tensor,
        batch: Tensor,
        edge_attr: Optional[Tensor],
        smiles: List[str],
    ) -> Tensor:
        """Encode molecules from their graph and their SMILES string.

        Args:
            x (Tensor): Node features passed to the graph encoder.
            edge_index (Tensor): Edge indices passed to the graph encoder.
            batch (Tensor): Assignment of each node to its example.
            edge_attr (Tensor, optional): Edge features, if any.
            smiles (List[str]): One SMILES string per example.

        Returns:
            Tensor: The flattened graph tokens and flattened SMILES tokens of
            each example, concatenated along dimension 1.
        """
        num_examples = len(smiles)

        def fixed_length_flat(tokens: Tensor) -> Tensor:
            # Pad/cut to `max_tokens` positions, then flatten per example.
            tokens = pad_or_truncate(tokens, max_seq_len=self.max_tokens,
                                     padding_value=0)
            return tokens.view(num_examples, -1)

        # 2D Graph Branch: [bs, node_len, d]
        x = x.to(self.llm.device)
        edge_index = edge_index.to(self.llm.device)
        if edge_attr is not None:
            edge_attr = edge_attr.to(self.llm.device)
        batch = batch.to(self.llm.device)

        node_emb = self.graph_encoder(x, edge_index, edge_attr=edge_attr)
        dense_node_emb, _ = to_dense_batch(node_emb, batch)
        graph_flat = fixed_length_flat(self.graph_qformer(dense_node_emb))

        # 1D SMILES Branch: [bs, seq_len, d]
        smiles_emb = self.smiles_encoder.encode(
            smiles, output_device=self.llm.device)
        smiles_flat = fixed_length_flat(self.smiles_qformer(smiles_emb))

        # Merge into LLMs
        return torch.cat([graph_flat, smiles_flat], dim=1)

    def _project_per_example(
        self,
        x: Tensor,
        edge_index: Tensor,
        batch: Tensor,
        edge_attr: Optional[Tensor],
        smiles: List[str],
        num_prompts: int,
    ):
        """Encode, project and split the molecule embeddings per example.

        Returns one projected row per example. When :obj:`batch` references
        fewer distinct examples than :obj:`num_prompts`, a list is returned
        in which examples missing from :obj:`batch` are :obj:`None`.
        """
        projected = self.projector(
            self.encode(x, edge_index, batch, edge_attr, smiles))
        rows = projected.split(1, dim=0)

        # Handle questions without node features:
        present = batch.unique()
        if len(present) >= num_prompts:
            return rows
        return [rows[i] if i in present else None for i in range(num_prompts)]

    def forward(
        self,
        x: Tensor,
        edge_index: Tensor,
        batch: Tensor,
        edge_attr: Optional[Tensor],
        smiles: List[str],
        instructions: List[str],
        label: List[str],
        additional_text_context: Optional[List[str]] = None,
    ):
        """Compute the training loss.

        Args:
            x (Tensor): Node features.
            edge_index (Tensor): Edge indices.
            batch (Tensor): Assignment of each node to its example.
            edge_attr (Tensor, optional): Edge features, if any.
            smiles (List[str]): One SMILES string per example.
            instructions (List[str]): One instruction per example.
            label (List[str]): The target text per example.
            additional_text_context (List[str], optional): Extra text
                forwarded to the LLM. (default: :obj:`None`)

        Returns:
            The ``loss`` attribute of the LLM output.
        """
        xs = self._project_per_example(x, edge_index, batch, edge_attr,
                                       smiles, len(instructions))

        embeds, mask, label_ids = self.llm._get_embeds(
            instructions, additional_text_context, xs, label)

        with self.llm.autocast_context:
            outputs = self.llm_generator(
                inputs_embeds=embeds,
                attention_mask=mask,
                return_dict=True,
                labels=label_ids,
            )

        return outputs.loss

    @torch.no_grad()
    def inference(
        self,
        x: Tensor,
        edge_index: Tensor,
        batch: Tensor,
        edge_attr: Optional[Tensor],
        smiles: List[str],
        instructions: List[str],
        additional_text_context: Optional[List[str]] = None,
        max_out_tokens: Optional[int] = MAX_NEW_TOKENS,
    ):
        """Generate text for each example without tracking gradients.

        Args:
            x (Tensor): Node features.
            edge_index (Tensor): Edge indices.
            batch (Tensor): Assignment of each node to its example.
            edge_attr (Tensor, optional): Edge features, if any.
            smiles (List[str]): One SMILES string per example.
            instructions (List[str]): One instruction per example.
            additional_text_context (List[str], optional): Extra text
                forwarded to the LLM. (default: :obj:`None`)
            max_out_tokens (int, optional): Passed to ``generate`` as
                ``max_new_tokens``. (default: ``MAX_NEW_TOKENS``)

        Returns:
            The generated sequences decoded by the tokenizer with special
            tokens skipped.
        """
        xs = self._project_per_example(x, edge_index, batch, edge_attr,
                                       smiles, len(instructions))

        embeds, mask, _ = self.llm._get_embeds(instructions,
                                               additional_text_context, xs)

        bos_encoding = self.llm.tokenizer(BOS, add_special_tokens=False)
        bos_token = bos_encoding.input_ids[0]

        with self.llm.autocast_context:
            outputs = self.llm_generator.generate(
                inputs_embeds=embeds,
                max_new_tokens=max_out_tokens,
                attention_mask=mask,
                bos_token_id=bos_token,
                use_cache=True  # Important to set!
            )

        return self.llm.tokenizer.batch_decode(outputs,
                                               skip_special_tokens=True)

    def __repr__(self) -> str:
        lines = (
            f'{self.__class__.__name__}(\n',
            f' llm={self.llm},\n',
            f' graph={self.graph_encoder.__class__.__name__},\n',
            f' smiles={self.smiles_encoder},\n',
            ')',
        )
        return ''.join(lines)