pyg-nightly 2.7.0.dev20241009__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +77 -53
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +226 -189
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
- pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
- torch_geometric/__init__.py +14 -2
- torch_geometric/_compile.py +9 -3
- torch_geometric/_onnx.py +214 -0
- torch_geometric/config_mixin.py +5 -3
- torch_geometric/config_store.py +1 -1
- torch_geometric/contrib/__init__.py +1 -1
- torch_geometric/contrib/explain/pgm_explainer.py +1 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +1 -3
- torch_geometric/data/data.py +109 -5
- torch_geometric/data/database.py +4 -0
- torch_geometric/data/dataset.py +14 -11
- torch_geometric/data/extract.py +1 -1
- torch_geometric/data/feature_store.py +17 -22
- torch_geometric/data/graph_store.py +3 -2
- torch_geometric/data/hetero_data.py +139 -7
- torch_geometric/data/hypergraph_data.py +2 -2
- torch_geometric/data/in_memory_dataset.py +2 -2
- torch_geometric/data/lightning/datamodule.py +42 -28
- torch_geometric/data/storage.py +9 -1
- torch_geometric/datasets/__init__.py +18 -1
- torch_geometric/datasets/actor.py +7 -9
- torch_geometric/datasets/airfrans.py +15 -17
- torch_geometric/datasets/airports.py +8 -10
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +8 -9
- torch_geometric/datasets/amazon_products.py +7 -9
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +8 -10
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/city.py +157 -0
- torch_geometric/datasets/dbp15k.py +1 -1
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/hgb_dataset.py +2 -2
- torch_geometric/datasets/hm.py +1 -1
- torch_geometric/datasets/instruct_mol_dataset.py +134 -0
- torch_geometric/datasets/md17.py +3 -3
- torch_geometric/datasets/medshapenet.py +145 -0
- torch_geometric/datasets/modelnet.py +1 -1
- torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
- torch_geometric/datasets/molecule_net.py +3 -2
- torch_geometric/datasets/ppi.py +2 -1
- torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
- torch_geometric/datasets/qm7.py +1 -1
- torch_geometric/datasets/qm9.py +1 -1
- torch_geometric/datasets/snap_dataset.py +8 -4
- torch_geometric/datasets/tag_dataset.py +462 -0
- torch_geometric/datasets/teeth3ds.py +269 -0
- torch_geometric/datasets/web_qsp_dataset.py +310 -209
- torch_geometric/datasets/wikics.py +2 -1
- torch_geometric/deprecation.py +1 -1
- torch_geometric/distributed/__init__.py +13 -0
- torch_geometric/distributed/dist_loader.py +2 -2
- torch_geometric/distributed/partition.py +2 -2
- torch_geometric/distributed/rpc.py +3 -3
- torch_geometric/edge_index.py +18 -14
- torch_geometric/explain/algorithm/attention_explainer.py +219 -29
- torch_geometric/explain/algorithm/base.py +2 -2
- torch_geometric/explain/algorithm/captum.py +1 -1
- torch_geometric/explain/algorithm/captum_explainer.py +2 -1
- torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
- torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
- torch_geometric/explain/algorithm/pg_explainer.py +305 -47
- torch_geometric/explain/explainer.py +2 -2
- torch_geometric/explain/explanation.py +87 -3
- torch_geometric/explain/metric/faithfulness.py +1 -1
- torch_geometric/graphgym/config.py +3 -2
- torch_geometric/graphgym/imports.py +15 -4
- torch_geometric/graphgym/logger.py +1 -1
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/models/encoder.py +2 -2
- torch_geometric/graphgym/models/layer.py +1 -1
- torch_geometric/graphgym/utils/comp_budget.py +4 -3
- torch_geometric/hash_tensor.py +798 -0
- torch_geometric/index.py +14 -5
- torch_geometric/inspector.py +4 -0
- torch_geometric/io/fs.py +5 -4
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/llm/large_graph_indexer.py +741 -0
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/{nn → llm}/models/g_retriever.py +77 -45
- torch_geometric/llm/models/git_mol.py +336 -0
- torch_geometric/llm/models/glem.py +397 -0
- torch_geometric/{nn/nlp → llm/models}/llm.py +179 -31
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/llm/models/molecule_gpt.py +222 -0
- torch_geometric/llm/models/protein_mpnn.py +333 -0
- torch_geometric/llm/models/sentence_transformer.py +188 -0
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/models/vision_transformer.py +38 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/__init__.py +10 -0
- torch_geometric/llm/utils/backend_utils.py +443 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +125 -0
- torch_geometric/loader/cluster.py +4 -4
- torch_geometric/loader/ibmb_loader.py +4 -4
- torch_geometric/loader/link_loader.py +1 -1
- torch_geometric/loader/link_neighbor_loader.py +2 -1
- torch_geometric/loader/mixin.py +6 -5
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +3 -2
- torch_geometric/loader/temporal_dataloader.py +2 -2
- torch_geometric/loader/utils.py +10 -10
- torch_geometric/metrics/__init__.py +14 -0
- torch_geometric/metrics/link_pred.py +745 -92
- torch_geometric/nn/__init__.py +1 -0
- torch_geometric/nn/aggr/base.py +1 -1
- torch_geometric/nn/aggr/equilibrium.py +1 -1
- torch_geometric/nn/aggr/fused.py +1 -1
- torch_geometric/nn/aggr/patch_transformer.py +8 -2
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/utils.py +9 -4
- torch_geometric/nn/attention/__init__.py +9 -1
- torch_geometric/nn/attention/polynormer.py +107 -0
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/attention/sgformer.py +99 -0
- torch_geometric/nn/conv/__init__.py +2 -0
- torch_geometric/nn/conv/appnp.py +1 -1
- torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
- torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
- torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
- torch_geometric/nn/conv/dna_conv.py +1 -1
- torch_geometric/nn/conv/eg_conv.py +7 -7
- torch_geometric/nn/conv/gen_conv.py +1 -1
- torch_geometric/nn/conv/gravnet_conv.py +2 -1
- torch_geometric/nn/conv/hetero_conv.py +2 -1
- torch_geometric/nn/conv/meshcnn_conv.py +487 -0
- torch_geometric/nn/conv/message_passing.py +5 -4
- torch_geometric/nn/conv/rgcn_conv.py +2 -1
- torch_geometric/nn/conv/sg_conv.py +1 -1
- torch_geometric/nn/conv/spline_conv.py +2 -1
- torch_geometric/nn/conv/ssg_conv.py +1 -1
- torch_geometric/nn/conv/transformer_conv.py +5 -3
- torch_geometric/nn/data_parallel.py +5 -4
- torch_geometric/nn/dense/linear.py +0 -20
- torch_geometric/nn/encoding.py +17 -3
- torch_geometric/nn/fx.py +14 -12
- torch_geometric/nn/model_hub.py +2 -15
- torch_geometric/nn/models/__init__.py +11 -2
- torch_geometric/nn/models/attentive_fp.py +1 -1
- torch_geometric/nn/models/attract_repel.py +148 -0
- torch_geometric/nn/models/basic_gnn.py +2 -1
- torch_geometric/nn/models/captum.py +1 -1
- torch_geometric/nn/models/deep_graph_infomax.py +1 -1
- torch_geometric/nn/models/dimenet.py +2 -2
- torch_geometric/nn/models/dimenet_utils.py +4 -2
- torch_geometric/nn/models/gpse.py +1083 -0
- torch_geometric/nn/models/graph_unet.py +13 -4
- torch_geometric/nn/models/lpformer.py +783 -0
- torch_geometric/nn/models/metapath2vec.py +1 -1
- torch_geometric/nn/models/mlp.py +4 -2
- torch_geometric/nn/models/node2vec.py +1 -1
- torch_geometric/nn/models/polynormer.py +206 -0
- torch_geometric/nn/models/rev_gnn.py +3 -3
- torch_geometric/nn/models/sgformer.py +219 -0
- torch_geometric/nn/models/signed_gcn.py +1 -1
- torch_geometric/nn/models/visnet.py +2 -2
- torch_geometric/nn/norm/batch_norm.py +17 -7
- torch_geometric/nn/norm/diff_group_norm.py +7 -2
- torch_geometric/nn/norm/graph_norm.py +9 -4
- torch_geometric/nn/norm/instance_norm.py +5 -1
- torch_geometric/nn/norm/layer_norm.py +15 -7
- torch_geometric/nn/norm/msg_norm.py +8 -2
- torch_geometric/nn/pool/__init__.py +8 -4
- torch_geometric/nn/pool/cluster_pool.py +3 -4
- torch_geometric/nn/pool/connect/base.py +1 -3
- torch_geometric/nn/pool/knn.py +13 -10
- torch_geometric/nn/pool/select/base.py +1 -4
- torch_geometric/nn/to_hetero_module.py +4 -3
- torch_geometric/nn/to_hetero_transformer.py +3 -3
- torch_geometric/nn/to_hetero_with_bases_transformer.py +4 -4
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/utils.py +20 -5
- torch_geometric/sampler/__init__.py +2 -1
- torch_geometric/sampler/base.py +336 -7
- torch_geometric/sampler/hgt_sampler.py +11 -1
- torch_geometric/sampler/neighbor_sampler.py +296 -23
- torch_geometric/sampler/utils.py +93 -5
- torch_geometric/testing/__init__.py +4 -0
- torch_geometric/testing/decorators.py +35 -5
- torch_geometric/testing/distributed.py +1 -1
- torch_geometric/transforms/__init__.py +2 -0
- torch_geometric/transforms/add_gpse.py +49 -0
- torch_geometric/transforms/add_metapaths.py +8 -6
- torch_geometric/transforms/add_positional_encoding.py +2 -2
- torch_geometric/transforms/base_transform.py +2 -1
- torch_geometric/transforms/delaunay.py +65 -15
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +7 -8
- torch_geometric/transforms/largest_connected_components.py +1 -1
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/normalize_features.py +3 -3
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_duplicated_edges.py +4 -2
- torch_geometric/transforms/rooted_subgraph.py +1 -1
- torch_geometric/typing.py +70 -17
- torch_geometric/utils/__init__.py +4 -1
- torch_geometric/utils/_lexsort.py +0 -9
- torch_geometric/utils/_negative_sampling.py +27 -12
- torch_geometric/utils/_scatter.py +132 -195
- torch_geometric/utils/_sort_edge_index.py +0 -2
- torch_geometric/utils/_spmm.py +16 -14
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_trim_to_layer.py +2 -2
- torch_geometric/utils/convert.py +17 -10
- torch_geometric/utils/cross_entropy.py +34 -13
- torch_geometric/utils/embedding.py +91 -2
- torch_geometric/utils/geodesic.py +4 -3
- torch_geometric/utils/influence.py +279 -0
- torch_geometric/utils/map.py +13 -9
- torch_geometric/utils/nested.py +1 -1
- torch_geometric/utils/smiles.py +3 -3
- torch_geometric/utils/sparse.py +7 -14
- torch_geometric/visualization/__init__.py +2 -1
- torch_geometric/visualization/graph.py +250 -5
- torch_geometric/warnings.py +11 -2
- torch_geometric/nn/nlp/__init__.py +0 -7
- torch_geometric/nn/nlp/sentence_transformer.py +0 -101
torch_geometric/llm/models/__init__.py (new file)
@@ -0,0 +1,23 @@
+from .sentence_transformer import SentenceTransformer
+from .vision_transformer import VisionTransformer
+from .llm import LLM
+from .txt2kg import TXT2KG
+from .llm_judge import LLMJudge
+from .g_retriever import GRetriever
+from .molecule_gpt import MoleculeGPT
+from .glem import GLEM
+from .protein_mpnn import ProteinMPNN
+from .git_mol import GITMol
+
+__all__ = classes = [
+    'SentenceTransformer',
+    'VisionTransformer',
+    'LLM',
+    'LLMJudge',
+    'TXT2KG',
+    'GRetriever',
+    'MoleculeGPT',
+    'GLEM',
+    'ProteinMPNN',
+    'GITMol',
+]
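The new package gathers every LLM-related model under one namespace (these previously lived under `torch_geometric.nn.nlp` and `torch_geometric.nn.models`). A minimal import sketch, assuming a 2.8.0.dev nightly (or later) with the optional HuggingFace dependencies installed:

```python
# Minimal sketch of the new torch_geometric.llm.models namespace; assumes a
# pyg-nightly build that ships the llm package (2.8.0.dev or later).
import torch_geometric.llm.models as llm_models

print(llm_models.classes)
# ['SentenceTransformer', 'VisionTransformer', 'LLM', 'LLMJudge', 'TXT2KG',
#  'GRetriever', 'MoleculeGPT', 'GLEM', 'ProteinMPNN', 'GITMol']
```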
torch_geometric/{nn → llm}/models/g_retriever.py
@@ -3,7 +3,7 @@ from typing import List, Optional
 import torch
 from torch import Tensor
 
-from torch_geometric.
+from torch_geometric.llm.models.llm import LLM, MAX_NEW_TOKENS
 from torch_geometric.utils import scatter
 
 
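Downstream code that imported these models from the old location needs only an import update; a guarded form keeps both versions working. A sketch, assuming the removed `torch_geometric.nn.nlp` package exported the same names in earlier releases:

```python
# Hedged migration sketch: prefer the new torch_geometric.llm.models location,
# falling back to the pre-2.8 torch_geometric.nn.nlp path (assumed name).
try:
    from torch_geometric.llm.models import LLM, SentenceTransformer
except ImportError:
    from torch_geometric.nn.nlp import LLM, SentenceTransformer
```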
@@ -19,17 +19,19 @@ class GRetriever(torch.nn.Module):
             :obj:`peft` for training the LLM, see
             `here <https://huggingface.co/docs/peft/en/index>`_ for details.
             (default: :obj:`False`)
-
-
+        mlp_out_tokens (int, optional): Number of LLM prefix tokens to
+            reserve for GNN output. (default: :obj:`1`)
 
     .. warning::
         This module has been tested with the following HuggingFace models
+        * :obj:`llm_to_use="meta-llama/Meta-Llama-3.1-8B-Instruct"`
+        * :obj:`llm_to_use="Qwen/Qwen3-0.6B"`
 
-        * :obj:`llm_to_use="meta-llama/Llama-2-7b-chat-hf"`
-        * :obj:`llm_to_use="google/gemma-7b"`
 
-
-
+        This module should work with any HuggingFace model.
+        See other models at `HuggingFace
+        Models <https://huggingface.co/models>`_
+        and let us know if you
         encounter any issues.
 
     .. note::
@@ -40,14 +42,14 @@ class GRetriever(torch.nn.Module):
     def __init__(
         self,
         llm: LLM,
-        gnn: torch.nn.Module,
+        gnn: torch.nn.Module = None,
         use_lora: bool = False,
-
+        mlp_out_tokens: int = 1,
     ) -> None:
         super().__init__()
 
         self.llm = llm
-        self.gnn = gnn.to(self.llm.device)
+        self.gnn = gnn.to(self.llm.device) if gnn is not None else None
 
         self.word_embedding = self.llm.word_embedding
         self.llm_generator = self.llm.llm
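With the updated signature, `gnn` is optional and `mlp_out_tokens` controls how many soft prefix tokens the pooled graph embedding is expanded into. A hedged construction sketch; the HuggingFace model name and channel sizes below are illustrative assumptions, not requirements:

```python
# Hedged construction sketch for the updated GRetriever signature; the model
# name and GNN sizes are assumptions chosen for illustration.
from torch_geometric.llm.models import LLM, GRetriever
from torch_geometric.nn import GAT

llm = LLM(model_name='Qwen/Qwen3-0.6B')  # any HuggingFace causal LM should work
gnn = GAT(in_channels=1024, hidden_channels=1024, out_channels=1024,
          num_layers=4, heads=4)

# Graph-conditioned variant: pooled GNN output becomes `mlp_out_tokens`
# prefix-token embeddings for the LLM.
model = GRetriever(llm=llm, gnn=gnn, mlp_out_tokens=1)

# Text-only variant, newly possible since `gnn` now defaults to None.
text_only = GRetriever(llm=llm)
```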
@@ -73,12 +75,18 @@ class GRetriever(torch.nn.Module):
             )
             self.llm_generator = get_peft_model(self.llm_generator, config)
 
-
-
-
-        torch.nn.
-
-
+        if self.gnn is not None:
+            mlp_out_channels = llm.word_embedding.embedding_dim
+            mlp_hidden_channels = self.gnn.out_channels
+            self.projector = torch.nn.Sequential(
+                torch.nn.Linear(mlp_hidden_channels, mlp_hidden_channels),
+                torch.nn.Sigmoid(),
+                torch.nn.Linear(mlp_hidden_channels,
+                                mlp_out_channels * mlp_out_tokens),
+                torch.nn.Unflatten(-1, (mlp_out_tokens, mlp_out_channels)),
+            ).to(self.llm.device)
+
+        self.seq_length_stats = []
 
     def encode(
         self,
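The projector above turns one pooled graph embedding per sample into `mlp_out_tokens` embeddings of the LLM's hidden size; the trailing `Unflatten` is what introduces the extra token dimension. A standalone shape check, with assumed sizes:

```python
# Standalone shape sketch of the projector wired up above; 1024 and 4096 are
# assumed GNN / LLM hidden sizes, not values fixed by the library.
import torch

mlp_hidden_channels = 1024   # assumed GNN output size
mlp_out_channels = 4096      # assumed LLM word-embedding size
mlp_out_tokens = 2

projector = torch.nn.Sequential(
    torch.nn.Linear(mlp_hidden_channels, mlp_hidden_channels),
    torch.nn.Sigmoid(),
    torch.nn.Linear(mlp_hidden_channels, mlp_out_channels * mlp_out_tokens),
    torch.nn.Unflatten(-1, (mlp_out_tokens, mlp_out_channels)),
)

pooled = torch.randn(8, mlp_hidden_channels)  # one embedding per graph
print(projector(pooled).shape)                # torch.Size([8, 2, 4096])
```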
@@ -93,7 +101,16 @@ class GRetriever(torch.nn.Module):
         edge_attr = edge_attr.to(self.llm.device)
         batch = batch.to(self.llm.device)
 
-
+        model_specific_kwargs = {}
+
+        # duck typing for SGFormer to get around circular import
+        if (hasattr(self.gnn, 'trans_conv')
+                and hasattr(self.gnn, 'graph_conv')):
+            model_specific_kwargs['batch'] = batch
+        else:
+            model_specific_kwargs['edge_attr'] = edge_attr
+
+        out = self.gnn(x, edge_index, **model_specific_kwargs)
         return scatter(out, batch, dim=0, reduce='mean')
 
     def forward(
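`encode()` ends by averaging node embeddings per graph, so each sample contributes exactly one vector to the projector. A small sketch of that reduction with `torch_geometric.utils.scatter`:

```python
# Small sketch of the per-graph mean pooling used at the end of encode().
import torch
from torch_geometric.utils import scatter

out = torch.arange(12, dtype=torch.float).view(6, 2)  # 6 node embeddings
batch = torch.tensor([0, 0, 0, 1, 1, 1])              # 2 graphs, 3 nodes each
graph_emb = scatter(out, batch, dim=0, reduce='mean')
print(graph_emb)  # tensor([[2., 3.], [8., 9.]])
```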
@@ -122,24 +139,32 @@ class GRetriever(torch.nn.Module):
                 to give to the LLM, such as textified knowledge graphs.
                 (default: :obj:`None`)
         """
-
-
-
-
-
-
-
-
-
-
-
-
+        xs = None
+        if self.gnn is not None:
+            x = self.encode(x, edge_index, batch, edge_attr)
+            x = self.projector(x)
+            xs = x.split(1, dim=0)
+
+            # Handle case where theres more than one embedding for each sample
+            xs = [x.squeeze(0) for x in xs]
+
+            # Handle questions without node features:
+            batch_unique = batch.unique()
+            batch_size = len(question)
+            if len(batch_unique) < batch_size:
+                xs = [
+                    xs[i] if i in batch_unique else None
+                    for i in range(batch_size)
+                ]
         (
             inputs_embeds,
             attention_mask,
             label_input_ids,
         ) = self.llm._get_embeds(question, additional_text_context, xs, label)
 
+        max_seq_len = inputs_embeds.size(1)
+        self.seq_length_stats.append(max_seq_len)
+
         with self.llm.autocast_context:
             outputs = self.llm_generator(
                 inputs_embeds=inputs_embeds,
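In the updated forward path, the pooled and projected graph embeddings become optional per-sample prefix inputs to the LLM, and samples whose graphs have no nodes fall back to text-only input. A hedged training-step sketch; `model` is a GRetriever as constructed earlier:

```python
# Hedged training-step sketch: `loader` and `optimizer` are placeholders, and
# the batch attribute names (question, label, edge_attr) are assumptions based
# on the G-Retriever example datasets rather than guaranteed by this diff.
for data in loader:
    loss = model(question=data.question, x=data.x,
                 edge_index=data.edge_index, batch=data.batch,
                 label=data.label, edge_attr=data.edge_attr)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```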
@@ -178,32 +203,39 @@ class GRetriever(torch.nn.Module):
             max_out_tokens (int, optional): How many tokens for the LLM to
                 generate. (default: :obj:`32`)
         """
-
-
-
-
-
-
-
-
-
-
-
+        xs = None
+        if self.gnn is not None:
+            x = self.encode(x, edge_index, batch, edge_attr)
+            x = self.projector(x)
+            xs = x.split(1, dim=0)
+
+            # Handle case where theres more than one embedding for each sample
+            xs = [x.squeeze(0) for x in xs]
+
+            # Handle questions without node features:
+            batch_unique = batch.unique()
+            batch_size = len(question)
+            if len(batch_unique) < batch_size:
+                xs = [
+                    xs[i] if i in batch_unique else None
+                    for i in range(batch_size)
+                ]
 
         inputs_embeds, attention_mask, _ = self.llm._get_embeds(
             question, additional_text_context, xs)
 
-        bos_token = self.llm.tokenizer(
-
-
-        ).input_ids[0]
+        # bos_token = self.llm.tokenizer(
+        #     self.llm.tokenizer.bos_token_id,
+        #     add_special_tokens=False,
+        # ).input_ids[0]
 
         with self.llm.autocast_context:
             outputs = self.llm_generator.generate(
                 inputs_embeds=inputs_embeds,
                 max_new_tokens=max_out_tokens,
                 attention_mask=attention_mask,
-                bos_token_id=
+                bos_token_id=self.llm.tokenizer.bos_token_id,
+                pad_token_id=self.llm.tokenizer.eos_token_id,
                 use_cache=True  # Important to set!
             )
 
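`inference()` mirrors the same graph-encoding path but calls `generate()` with the tokenizer's BOS/EOS ids. A hedged generation sketch under the same assumptions as the training sketch above:

```python
# Hedged generation sketch; `model` and `data` are the same placeholders as in
# the training sketch, and the output format is not shown in this diff.
preds = model.inference(question=data.question, x=data.x,
                        edge_index=data.edge_index, batch=data.batch,
                        edge_attr=data.edge_attr, max_out_tokens=32)
```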
torch_geometric/llm/models/git_mol.py (new file)
@@ -0,0 +1,336 @@
+from typing import List, Optional
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+from torch.nn import BatchNorm1d, LayerNorm, Linear, ReLU, Sequential
+
+from torch_geometric.llm.models import SentenceTransformer, VisionTransformer
+from torch_geometric.nn import GINEConv
+from torch_geometric.utils import add_self_loops, to_dense_batch
+
+
+class GraphEncoder(torch.nn.Module):
+    def __init__(
+        self,
+        num_layers: int,
+        in_channels: int,
+        dropout: float = 0.,
+        num_atom_type: int = 120,
+        num_chirality_tag: int = 3,
+        num_bond_type: int = 6,
+        num_bond_direction: int = 3,
+    ) -> None:
+        super().__init__()
+
+        self.num_layers = num_layers
+        self.dropout = dropout
+
+        self.x_embed1 = torch.nn.Embedding(num_atom_type, in_channels)
+        self.x_embed2 = torch.nn.Embedding(num_chirality_tag, in_channels)
+        self.edge_embed1 = torch.nn.Embedding(num_bond_type, in_channels)
+        self.edge_embed2 = torch.nn.Embedding(num_bond_direction, in_channels)
+
+        self.gnns = torch.nn.ModuleList()
+        self.batch_norms = torch.nn.ModuleList()
+        for _ in range(num_layers):
+            self.gnns.append(
+                GINEConv(
+                    nn=Sequential(
+                        Linear(in_channels, in_channels * 2),
+                        ReLU(),
+                        Linear(in_channels * 2, in_channels),
+                    ),
+                    train_eps=True,
+                    edge_dim=in_channels,
+                ))
+            self.batch_norms.append(BatchNorm1d(in_channels))
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        torch.nn.init.xavier_uniform_(self.x_embed1.weight.data)
+        torch.nn.init.xavier_uniform_(self.x_embed2.weight.data)
+        torch.nn.init.xavier_uniform_(self.edge_embed1.weight.data)
+        torch.nn.init.xavier_uniform_(self.edge_embed2.weight.data)
+
+    def forward(
+        self,
+        x: Tensor,
+        edge_index: Tensor,
+        batch: Tensor,
+        edge_attr: Tensor,
+    ) -> Tensor:
+        x = self.x_embed1(x[:, 0].long()) + self.x_embed2(x[:, 1].long())
+        edge_index, edge_attr = add_self_loops(
+            edge_index,
+            edge_attr,
+            fill_value=0,
+            num_nodes=x.size(0),
+        )
+        edge_attr = self.edge_embed1(edge_attr[:, 0]) + self.edge_embed2(
+            edge_attr[:, 1])
+        for i, (gnn, bn) in enumerate(zip(self.gnns, self.batch_norms)):
+            x = gnn(x, edge_index, edge_attr)
+            x = bn(x)
+            if i < self.num_layers - 1:
+                x = F.relu(x)
+            x = F.dropout(x, self.dropout, training=self.training)
+
+        x, mask = to_dense_batch(x, batch)
+        return x, mask
+
+
+class GITFormer(torch.nn.Module):
+    def __init__(
+        self,
+        num_query_token: int,
+        vision_graph_width: int,
+        cross_attention_freq: int = 2,
+    ):
+        super().__init__()
+        from transformers import AutoConfig, AutoModel
+
+        config = AutoConfig.from_pretrained("allenai/scibert_scivocab_uncased")
+        config.encoder_width = vision_graph_width
+        # insert cross-attention layer every other block
+        config.add_cross_attention = True
+        config.is_decoder = True
+        config.cross_attention_freq = cross_attention_freq
+        config.query_length = num_query_token
+        self.Qformer = AutoModel.from_pretrained(
+            "allenai/scibert_scivocab_uncased", config=config)
+        self.query_tokens = torch.nn.Parameter(
+            torch.zeros(1, num_query_token, config.hidden_size))
+        self.query_tokens.data.normal_(mean=0.0, std=config.initializer_range)
+
+
+class GITMol(torch.nn.Module):
+    r"""The GITMol model from the `"GIT-Mol: A Multi-modal Large Language
+    Model for Molecular Science with Graph, Image, and Text"
+    <https://arxiv.org/pdf/2308.06911>`_ paper.
+
+    .. note::
+        For an example of using :class:`GITMol`, see
+        `examples/llm/git_mol.py <https://github.com/pyg-team/
+        pytorch_geometric/blob/master/examples/llm/git_mol.py>`_.
+    """
+    def __init__(self) -> None:
+        super().__init__()
+        # graph
+        self.graph_encoder = GraphEncoder(num_layers=2, in_channels=16)
+        self.graph_proj = Linear(16, 768)
+        self.ln_graph = LayerNorm(768)
+        # text
+        self.text_encoder = SentenceTransformer(
+            model_name='allenai/scibert_scivocab_uncased',
+            pooling_strategy='last_hidden_state',
+        )
+        self.text_proj = Linear(768, 768)
+        self.ln_text = LayerNorm(768)
+        # vision
+        self.vision_encoder = VisionTransformer(
+            model_name='microsoft/swin-base-patch4-window7-224', )
+        self.vision_proj = Linear(1024, 768)
+        self.ln_vision = LayerNorm(768)
+        # cross-attention
+        self.gitformer = GITFormer(384, 768)
+
+        self.xtm_head = torch.nn.ModuleDict({
+            'image':
+            Linear(self.gitformer.Qformer.config.hidden_size, 2),
+            'graph':
+            Linear(self.gitformer.Qformer.config.hidden_size, 2),
+            'cs_text':
+            Linear(self.gitformer.Qformer.config.hidden_size, 2),
+        })
+
+        self.xtc_proj = torch.nn.ModuleDict({
+            'image':
+            Linear(self.gitformer.Qformer.config.hidden_size, 768),
+            'graph':
+            Linear(self.gitformer.Qformer.config.hidden_size, 768),
+            'cs_text':
+            Linear(self.gitformer.Qformer.config.hidden_size, 768),
+        })
+        self.temp = torch.nn.Parameter(0.07 * torch.ones([]))
+        self.model_freeze()
+
+    def model_freeze(self) -> None:
+        for param in self.graph_encoder.parameters():
+            param.requires_grad = False
+
+        for param in self.vision_encoder.parameters():
+            param.requires_grad = False
+
+    def forward(
+        self,
+        x: Tensor,
+        edge_index: Tensor,
+        batch: Tensor,
+        edge_attr: Optional[Tensor],
+        smiles: List[str],
+        images: Tensor,
+        captions: List[str],
+    ) -> Tensor:
+        batch_size = len(smiles)
+
+        x_vision = self.vision_encoder(images)
+        x_vision = self.vision_proj(x_vision)
+        x_vision = self.ln_vision(x_vision)  # [bs, patch_len, d]
+        vision_atts = torch.ones(x_vision.size()[:-1],
+                                 dtype=torch.long).to(x_vision.device)
+        vision_targets = torch.arange(batch_size).to(x_vision.device)
+
+        x_graph, graph_atts = self.graph_encoder(x, edge_index, batch,
+                                                 edge_attr)
+        x_graph = self.graph_proj(x_graph)
+        x_graph = self.ln_graph(x_graph)  # [bs, node_len, d]
+        graph_targets = torch.arange(batch_size).to(x_graph.device)
+
+        x_smiles = self.text_encoder.encode(smiles)  # [bs, seq_len, d]
+        smiles_atts = torch.ones(x_smiles.size()[:-1],
+                                 dtype=torch.long).to(x_smiles.device)
+        smiles_targets = torch.arange(batch_size).to(x_smiles.device)
+
+        caption_input_ids, caption_attention_masks = self.text_encoder.get_input_ids(  # noqa: E501
+            captions)
+
+        text_output = self.gitformer.Qformer(
+            caption_input_ids,
+            attention_mask=caption_attention_masks,
+            return_dict=True,
+        )
+        text_feat = F.normalize(
+            self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1)
+
+        loss = 0
+        for x_embed, x_atts, x_targets, modal in zip(
+            [x_graph, x_smiles, x_vision],
+            [graph_atts, smiles_atts, vision_atts],
+            [graph_targets, smiles_targets, vision_targets],
+            ['graph', 'cs_text', 'image'],
+        ):
+            loss += self._calc_xtc_loss(x_embed, x_atts, x_targets, text_feat,
+                                        modal)
+            loss += self._calc_xtm_loss(x_embed, caption_input_ids,
+                                        caption_attention_masks, modal)
+
+        return loss / 6
+
+    def _calc_xtm_loss(
+        self,
+        x_embeds: Tensor,
+        input_ids: Tensor,
+        attention_mask: Tensor,
+        modal: str,
+    ) -> Tensor:
+        # Initializing lists to hold the original and negative samples
+        x_embeds_list = []
+        text_input_ids_list = []
+        text_attention_mask_list = []
+
+        batch_size = x_embeds.size(0)
+        for i in range(batch_size):
+            # Original samples
+            x_embeds_list.append(x_embeds[i])
+            text_input_ids_list.append(input_ids[i, :])
+            text_attention_mask_list.append(attention_mask[i, :])
+
+            if batch_size > 1:
+                # Negative samples (neg_text_input_ids corresponds to x_embeds)
+                neg_text_input_ids = input_ids[i - 1 if i == batch_size -
+                                               1 else i + 1, :]
+                neg_text_attention_mask = attention_mask[i -
+                                                         1 if i == batch_size -
+                                                         1 else i + 1, :]
+                text_input_ids_list.append(neg_text_input_ids)
+                text_attention_mask_list.append(neg_text_attention_mask)
+                x_embeds_list.append(x_embeds[i, :])
+
+                # Negative samples (text_input_ids corresponds to neg_x_embeds)
+                neg_x_embeds = x_embeds[i - 1 if i == batch_size - 1 else i +
+                                        1, :]
+                x_embeds_list.append(neg_x_embeds)
+                text_input_ids_list.append(input_ids[i, :])
+                text_attention_mask_list.append(attention_mask[i, :])
+
+        # Stack all samples into two large tensors
+        x_embeds_all = torch.stack(x_embeds_list, dim=1) \
+            .reshape(-1, x_embeds.size(1), x_embeds.size(2))
+        text_input_ids_all = torch.stack(text_input_ids_list, dim=1) \
+            .reshape(-1, input_ids.size(1))
+        # Create image attention masks for the concatenated tensor
+        image_attns_all = torch.ones(x_embeds_all.size()[:-1],
+                                     dtype=torch.long).to(x_embeds_all.device)
+        query_tokens_xtm = self.gitformer.query_tokens.expand(
+            text_input_ids_all.shape[0], -1, -1)
+        query_attns_xtm = torch.ones(query_tokens_xtm.size()[:-1],
+                                     dtype=torch.long).to(x_embeds_all.device)
+
+        output_xtm = self.gitformer.Qformer(
+            inputs_embeds=query_tokens_xtm,
+            attention_mask=query_attns_xtm,
+            encoder_hidden_states=x_embeds_all,
+            encoder_attention_mask=image_attns_all,
+            return_dict=True,
+        ).last_hidden_state
+
+        xtm_embeddings = output_xtm[:, :query_tokens_xtm.size(1), :]
+
+        xtm_logit = self.xtm_head[modal](xtm_embeddings).mean(dim=1)
+        # Create labels: 1 for the original samples, 0 for the negative samples
+        if batch_size > 1:
+            labels = torch.cat(
+                [torch.ones(batch_size),
+                 torch.zeros(batch_size * 2)], dim=0)
+        else:
+            labels = torch.ones(batch_size)
+        labels = labels.long().to(xtm_logit.device)
+
+        # Calculate cross entropy loss
+        return F.cross_entropy(xtm_logit, labels)
+
+    def _calc_xtc_loss(
+        self,
+        x_embeds: Tensor,
+        x_atts: Tensor,
+        x_targets: Tensor,
+        text_feat: Tensor,
+        modal: str,
+    ) -> Tensor:
+        query_tokens = self.gitformer.query_tokens.expand(
+            x_embeds.shape[0], -1, -1)
+
+        query_output = self.gitformer.Qformer(
+            inputs_embeds=query_tokens,
+            encoder_hidden_states=x_embeds,
+            encoder_attention_mask=x_atts,
+            return_dict=True,
+        ).last_hidden_state
+
+        x_feats = F.normalize(self.xtc_proj[modal](query_output), dim=-1)
+
+        sim_q2t = torch.matmul(
+            x_feats.unsqueeze(1),
+            text_feat.unsqueeze(-1),
+        ).squeeze(-1)
+
+        # modal-text similarity: aggregate across all query tokens
+        sim_x2t, _ = sim_q2t.max(-1)
+        sim_x2t = sim_x2t / self.temp
+
+        # text-query similarity
+        sim_t2q = torch.matmul(
+            text_feat.unsqueeze(1).unsqueeze(1),
+            x_feats.permute(0, 2, 1),
+        ).squeeze(-2)
+
+        # text-modal similarity: aggregate across all query tokens
+        sim_t2x, _ = sim_t2q.max(-1)
+        sim_t2x = sim_t2x / self.temp
+
+        loss_itc = (
+            F.cross_entropy(sim_x2t, x_targets, label_smoothing=0.1) +
+            F.cross_entropy(sim_t2x, x_targets, label_smoothing=0.1)) / 2
+
+        return loss_itc
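GITMol combines a frozen GIN-based graph encoder, a SciBERT text encoder, and a Swin vision encoder through a Q-Former-style cross-attention module, and its forward pass returns the averaged contrastive (xtc) and matching (xtm) losses over the three modality/text pairs. A toy forward-pass sketch; it downloads the pretrained SciBERT and Swin checkpoints, and the tensor shapes and 224x224 image size are assumptions:

```python
# Toy forward-pass sketch for GITMol; downloads HuggingFace checkpoints, and
# the input shapes below (atom/bond index columns, 224x224 images) are
# assumptions for illustration only.
import torch
from torch_geometric.llm.models import GITMol

model = GITMol()

x = torch.zeros(4, 2, dtype=torch.long)          # [atom type, chirality tag]
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
edge_attr = torch.zeros(4, 2, dtype=torch.long)  # [bond type, bond direction]
batch = torch.zeros(4, dtype=torch.long)         # a single molecule
smiles = ['CCO']
captions = ['ethanol']
images = torch.randn(1, 3, 224, 224)

loss = model(x, edge_index, batch, edge_attr, smiles, images, captions)
loss.backward()
```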