pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pyg-nightly might be problematic. Click here for more details.

Files changed (268)
  1. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
  2. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
  3. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +34 -1
  6. torch_geometric/_compile.py +11 -3
  7. torch_geometric/_onnx.py +228 -0
  8. torch_geometric/config_mixin.py +8 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/__init__.py +19 -1
  13. torch_geometric/data/batch.py +2 -2
  14. torch_geometric/data/collate.py +1 -3
  15. torch_geometric/data/data.py +110 -6
  16. torch_geometric/data/database.py +19 -5
  17. torch_geometric/data/dataset.py +14 -9
  18. torch_geometric/data/extract.py +1 -1
  19. torch_geometric/data/feature_store.py +17 -22
  20. torch_geometric/data/graph_store.py +3 -2
  21. torch_geometric/data/hetero_data.py +139 -7
  22. torch_geometric/data/hypergraph_data.py +2 -2
  23. torch_geometric/data/in_memory_dataset.py +2 -2
  24. torch_geometric/data/lightning/datamodule.py +42 -28
  25. torch_geometric/data/storage.py +9 -1
  26. torch_geometric/datasets/__init__.py +20 -1
  27. torch_geometric/datasets/actor.py +7 -9
  28. torch_geometric/datasets/airfrans.py +17 -20
  29. torch_geometric/datasets/airports.py +8 -10
  30. torch_geometric/datasets/amazon.py +8 -11
  31. torch_geometric/datasets/amazon_book.py +8 -9
  32. torch_geometric/datasets/amazon_products.py +7 -9
  33. torch_geometric/datasets/aminer.py +8 -9
  34. torch_geometric/datasets/aqsol.py +10 -13
  35. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  36. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  37. torch_geometric/datasets/ba_shapes.py +5 -6
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/city.py +157 -0
  40. torch_geometric/datasets/dbp15k.py +1 -1
  41. torch_geometric/datasets/gdelt_lite.py +3 -2
  42. torch_geometric/datasets/ged_dataset.py +3 -2
  43. torch_geometric/datasets/git_mol_dataset.py +263 -0
  44. torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
  45. torch_geometric/datasets/hgb_dataset.py +2 -2
  46. torch_geometric/datasets/hm.py +1 -1
  47. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  48. torch_geometric/datasets/linkx_dataset.py +4 -3
  49. torch_geometric/datasets/lrgb.py +3 -5
  50. torch_geometric/datasets/malnet_tiny.py +2 -1
  51. torch_geometric/datasets/md17.py +3 -3
  52. torch_geometric/datasets/medshapenet.py +145 -0
  53. torch_geometric/datasets/mnist_superpixels.py +2 -3
  54. torch_geometric/datasets/modelnet.py +1 -1
  55. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  56. torch_geometric/datasets/molecule_net.py +3 -2
  57. torch_geometric/datasets/neurograph.py +1 -3
  58. torch_geometric/datasets/ogb_mag.py +1 -1
  59. torch_geometric/datasets/opf.py +19 -5
  60. torch_geometric/datasets/pascal_pf.py +1 -1
  61. torch_geometric/datasets/pcqm4m.py +2 -1
  62. torch_geometric/datasets/ppi.py +2 -1
  63. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  64. torch_geometric/datasets/qm7.py +1 -1
  65. torch_geometric/datasets/qm9.py +3 -2
  66. torch_geometric/datasets/shrec2016.py +2 -2
  67. torch_geometric/datasets/snap_dataset.py +8 -4
  68. torch_geometric/datasets/tag_dataset.py +462 -0
  69. torch_geometric/datasets/teeth3ds.py +269 -0
  70. torch_geometric/datasets/web_qsp_dataset.py +342 -0
  71. torch_geometric/datasets/wikics.py +2 -1
  72. torch_geometric/datasets/wikidata.py +2 -1
  73. torch_geometric/deprecation.py +1 -1
  74. torch_geometric/distributed/__init__.py +13 -0
  75. torch_geometric/distributed/dist_loader.py +2 -2
  76. torch_geometric/distributed/local_feature_store.py +3 -2
  77. torch_geometric/distributed/local_graph_store.py +2 -1
  78. torch_geometric/distributed/partition.py +9 -8
  79. torch_geometric/distributed/rpc.py +3 -3
  80. torch_geometric/edge_index.py +35 -22
  81. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  82. torch_geometric/explain/algorithm/base.py +2 -2
  83. torch_geometric/explain/algorithm/captum.py +1 -1
  84. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  85. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  86. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  87. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  88. torch_geometric/explain/explainer.py +2 -2
  89. torch_geometric/explain/explanation.py +89 -5
  90. torch_geometric/explain/metric/faithfulness.py +1 -1
  91. torch_geometric/graphgym/checkpoint.py +2 -1
  92. torch_geometric/graphgym/config.py +3 -2
  93. torch_geometric/graphgym/imports.py +15 -4
  94. torch_geometric/graphgym/logger.py +1 -1
  95. torch_geometric/graphgym/loss.py +1 -1
  96. torch_geometric/graphgym/models/encoder.py +2 -2
  97. torch_geometric/graphgym/models/layer.py +1 -1
  98. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  99. torch_geometric/hash_tensor.py +798 -0
  100. torch_geometric/index.py +16 -7
  101. torch_geometric/inspector.py +6 -2
  102. torch_geometric/io/fs.py +27 -0
  103. torch_geometric/io/tu.py +2 -3
  104. torch_geometric/llm/__init__.py +9 -0
  105. torch_geometric/llm/large_graph_indexer.py +741 -0
  106. torch_geometric/llm/models/__init__.py +23 -0
  107. torch_geometric/llm/models/g_retriever.py +251 -0
  108. torch_geometric/llm/models/git_mol.py +336 -0
  109. torch_geometric/llm/models/glem.py +397 -0
  110. torch_geometric/llm/models/llm.py +470 -0
  111. torch_geometric/llm/models/llm_judge.py +158 -0
  112. torch_geometric/llm/models/molecule_gpt.py +222 -0
  113. torch_geometric/llm/models/protein_mpnn.py +333 -0
  114. torch_geometric/llm/models/sentence_transformer.py +188 -0
  115. torch_geometric/llm/models/txt2kg.py +353 -0
  116. torch_geometric/llm/models/vision_transformer.py +38 -0
  117. torch_geometric/llm/rag_loader.py +154 -0
  118. torch_geometric/llm/utils/__init__.py +10 -0
  119. torch_geometric/llm/utils/backend_utils.py +443 -0
  120. torch_geometric/llm/utils/feature_store.py +169 -0
  121. torch_geometric/llm/utils/graph_store.py +199 -0
  122. torch_geometric/llm/utils/vectorrag.py +125 -0
  123. torch_geometric/loader/cluster.py +6 -5
  124. torch_geometric/loader/graph_saint.py +2 -1
  125. torch_geometric/loader/ibmb_loader.py +4 -4
  126. torch_geometric/loader/link_loader.py +1 -1
  127. torch_geometric/loader/link_neighbor_loader.py +2 -1
  128. torch_geometric/loader/mixin.py +6 -5
  129. torch_geometric/loader/neighbor_loader.py +1 -1
  130. torch_geometric/loader/neighbor_sampler.py +2 -2
  131. torch_geometric/loader/prefetch.py +4 -3
  132. torch_geometric/loader/temporal_dataloader.py +2 -2
  133. torch_geometric/loader/utils.py +10 -10
  134. torch_geometric/metrics/__init__.py +23 -2
  135. torch_geometric/metrics/link_pred.py +755 -85
  136. torch_geometric/nn/__init__.py +1 -0
  137. torch_geometric/nn/aggr/__init__.py +2 -0
  138. torch_geometric/nn/aggr/base.py +1 -1
  139. torch_geometric/nn/aggr/equilibrium.py +1 -1
  140. torch_geometric/nn/aggr/fused.py +1 -1
  141. torch_geometric/nn/aggr/patch_transformer.py +149 -0
  142. torch_geometric/nn/aggr/set_transformer.py +1 -1
  143. torch_geometric/nn/aggr/utils.py +9 -4
  144. torch_geometric/nn/attention/__init__.py +9 -1
  145. torch_geometric/nn/attention/polynormer.py +107 -0
  146. torch_geometric/nn/attention/qformer.py +71 -0
  147. torch_geometric/nn/attention/sgformer.py +99 -0
  148. torch_geometric/nn/conv/__init__.py +2 -0
  149. torch_geometric/nn/conv/appnp.py +1 -1
  150. torch_geometric/nn/conv/collect.jinja +6 -3
  151. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  152. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  153. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  154. torch_geometric/nn/conv/dna_conv.py +1 -1
  155. torch_geometric/nn/conv/eg_conv.py +7 -7
  156. torch_geometric/nn/conv/gat_conv.py +33 -4
  157. torch_geometric/nn/conv/gatv2_conv.py +35 -4
  158. torch_geometric/nn/conv/gen_conv.py +1 -1
  159. torch_geometric/nn/conv/general_conv.py +1 -1
  160. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  161. torch_geometric/nn/conv/hetero_conv.py +3 -2
  162. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  163. torch_geometric/nn/conv/message_passing.py +6 -5
  164. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  165. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  166. torch_geometric/nn/conv/sg_conv.py +1 -1
  167. torch_geometric/nn/conv/spline_conv.py +2 -1
  168. torch_geometric/nn/conv/ssg_conv.py +1 -1
  169. torch_geometric/nn/conv/transformer_conv.py +5 -3
  170. torch_geometric/nn/data_parallel.py +5 -4
  171. torch_geometric/nn/dense/linear.py +5 -24
  172. torch_geometric/nn/encoding.py +17 -3
  173. torch_geometric/nn/fx.py +17 -15
  174. torch_geometric/nn/model_hub.py +5 -16
  175. torch_geometric/nn/models/__init__.py +11 -0
  176. torch_geometric/nn/models/attentive_fp.py +1 -1
  177. torch_geometric/nn/models/attract_repel.py +148 -0
  178. torch_geometric/nn/models/basic_gnn.py +2 -1
  179. torch_geometric/nn/models/captum.py +1 -1
  180. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  181. torch_geometric/nn/models/dimenet.py +2 -2
  182. torch_geometric/nn/models/dimenet_utils.py +4 -2
  183. torch_geometric/nn/models/gpse.py +1083 -0
  184. torch_geometric/nn/models/graph_unet.py +13 -4
  185. torch_geometric/nn/models/lpformer.py +783 -0
  186. torch_geometric/nn/models/metapath2vec.py +1 -1
  187. torch_geometric/nn/models/mlp.py +4 -2
  188. torch_geometric/nn/models/node2vec.py +1 -1
  189. torch_geometric/nn/models/polynormer.py +206 -0
  190. torch_geometric/nn/models/rev_gnn.py +3 -3
  191. torch_geometric/nn/models/schnet.py +2 -1
  192. torch_geometric/nn/models/sgformer.py +219 -0
  193. torch_geometric/nn/models/signed_gcn.py +1 -1
  194. torch_geometric/nn/models/visnet.py +2 -2
  195. torch_geometric/nn/norm/batch_norm.py +17 -7
  196. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  197. torch_geometric/nn/norm/graph_norm.py +9 -4
  198. torch_geometric/nn/norm/instance_norm.py +5 -1
  199. torch_geometric/nn/norm/layer_norm.py +15 -7
  200. torch_geometric/nn/norm/msg_norm.py +8 -2
  201. torch_geometric/nn/pool/__init__.py +15 -9
  202. torch_geometric/nn/pool/cluster_pool.py +144 -0
  203. torch_geometric/nn/pool/connect/base.py +1 -3
  204. torch_geometric/nn/pool/edge_pool.py +1 -1
  205. torch_geometric/nn/pool/knn.py +13 -10
  206. torch_geometric/nn/pool/select/base.py +1 -4
  207. torch_geometric/nn/summary.py +1 -1
  208. torch_geometric/nn/to_hetero_module.py +4 -3
  209. torch_geometric/nn/to_hetero_transformer.py +3 -3
  210. torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
  211. torch_geometric/profile/__init__.py +2 -0
  212. torch_geometric/profile/nvtx.py +66 -0
  213. torch_geometric/profile/profiler.py +18 -9
  214. torch_geometric/profile/utils.py +20 -5
  215. torch_geometric/sampler/__init__.py +2 -1
  216. torch_geometric/sampler/base.py +337 -8
  217. torch_geometric/sampler/hgt_sampler.py +11 -1
  218. torch_geometric/sampler/neighbor_sampler.py +298 -25
  219. torch_geometric/sampler/utils.py +93 -5
  220. torch_geometric/testing/__init__.py +4 -0
  221. torch_geometric/testing/decorators.py +35 -5
  222. torch_geometric/testing/distributed.py +1 -1
  223. torch_geometric/transforms/__init__.py +4 -0
  224. torch_geometric/transforms/add_gpse.py +49 -0
  225. torch_geometric/transforms/add_metapaths.py +10 -8
  226. torch_geometric/transforms/add_positional_encoding.py +2 -2
  227. torch_geometric/transforms/base_transform.py +2 -1
  228. torch_geometric/transforms/delaunay.py +65 -15
  229. torch_geometric/transforms/face_to_edge.py +32 -3
  230. torch_geometric/transforms/gdc.py +8 -9
  231. torch_geometric/transforms/largest_connected_components.py +1 -1
  232. torch_geometric/transforms/mask.py +5 -1
  233. torch_geometric/transforms/node_property_split.py +1 -1
  234. torch_geometric/transforms/normalize_features.py +3 -3
  235. torch_geometric/transforms/pad.py +1 -1
  236. torch_geometric/transforms/random_link_split.py +1 -1
  237. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  238. torch_geometric/transforms/remove_self_loops.py +36 -0
  239. torch_geometric/transforms/rooted_subgraph.py +1 -1
  240. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  241. torch_geometric/transforms/virtual_node.py +2 -1
  242. torch_geometric/typing.py +82 -17
  243. torch_geometric/utils/__init__.py +6 -1
  244. torch_geometric/utils/_lexsort.py +0 -9
  245. torch_geometric/utils/_negative_sampling.py +28 -13
  246. torch_geometric/utils/_normalize_edge_index.py +46 -0
  247. torch_geometric/utils/_scatter.py +126 -164
  248. torch_geometric/utils/_sort_edge_index.py +0 -2
  249. torch_geometric/utils/_spmm.py +16 -14
  250. torch_geometric/utils/_subgraph.py +4 -0
  251. torch_geometric/utils/_tree_decomposition.py +1 -1
  252. torch_geometric/utils/_trim_to_layer.py +2 -2
  253. torch_geometric/utils/augmentation.py +1 -1
  254. torch_geometric/utils/convert.py +17 -10
  255. torch_geometric/utils/cross_entropy.py +34 -13
  256. torch_geometric/utils/embedding.py +91 -2
  257. torch_geometric/utils/geodesic.py +28 -25
  258. torch_geometric/utils/influence.py +279 -0
  259. torch_geometric/utils/map.py +14 -10
  260. torch_geometric/utils/nested.py +1 -1
  261. torch_geometric/utils/smiles.py +3 -3
  262. torch_geometric/utils/sparse.py +32 -24
  263. torch_geometric/visualization/__init__.py +2 -1
  264. torch_geometric/visualization/graph.py +250 -5
  265. torch_geometric/warnings.py +11 -2
  266. torch_geometric/nn/nlp/__init__.py +0 -7
  267. torch_geometric/nn/nlp/llm.py +0 -283
  268. torch_geometric/nn/nlp/sentence_transformer.py +0 -94
@@ -1,283 +0,0 @@
1
- import warnings
2
- from contextlib import nullcontext
3
- from typing import Any, Dict, List, Optional
4
-
5
- import torch
6
- from torch import Tensor
7
-
8
- BOS = '<s>[INST]'
9
- EOS_USER = '[/INST]'
10
- EOS = '[/s]'
11
- IGNORE_INDEX = -100
12
- MAX_TXT_LEN = 512
13
- MAX_NEW_TOKENS = 32
14
- PAD_TOKEN_ID = 0
15
- PADDING_SIDE = 'left'
16
-
17
-
18
- def get_llm_kwargs(required_memory: int, dtype=torch.dtype) -> Dict[str, Any]:
19
- torch.cuda.empty_cache()
20
-
21
- gpu_memory: List[int] = []
22
- for i in range(torch.cuda.device_count()):
23
- gpu_memory.append(torch.cuda.mem_get_info(i)[0] // 1024**3)
24
- # Use the minimum number of GPUs to fit the LLM on.
25
- if sum(gpu_memory) >= required_memory:
26
- break
27
-
28
- if sum(gpu_memory) < required_memory:
29
- gpu_memory = [] # If not enough VRAM, use pure CPU.
30
-
31
- kwargs = dict(revision='main')
32
- if len(gpu_memory) > 0:
33
- kwargs['max_memory'] = {
34
- i: f'{memory}GiB'
35
- for i, memory in enumerate(gpu_memory)
36
- }
37
- kwargs['low_cpu_mem_usage'] = True
38
- kwargs['device_map'] = 'auto'
39
- kwargs['torch_dtype'] = dtype
40
-
41
- return kwargs
42
-
43
-
44
- class LLM(torch.nn.Module):
45
- r"""A wrapper around a Large Language Model (LLM) from HuggingFace.
46
-
47
- model_name (str): The HuggingFace model name, *e.g.*, :obj:`"llama2"` or
48
- :obj:`"gemma"`.
49
- num_params (int): An integer representing how many parameters the
50
- HuggingFace model has, in billions. This is used to automatically
51
- allocate the correct number of GPUs needed, given the available GPU
52
- memory of your GPUs.
53
- dtype (torch.dtype, optional): The data type to use for the LLM.
54
- (default :obj: `torch.bloat16`)
55
- """
56
- def __init__(
57
- self,
58
- model_name: str,
59
- num_params: int,
60
- dtype=torch.bfloat16,
61
- ) -> None:
62
- super().__init__()
63
-
64
- from transformers import AutoModelForCausalLM, AutoTokenizer
65
-
66
- if model_name == 'llama2-7b':
67
- pretty_model_name = 'LLAMA2'
68
- model_name = 'meta-llama/Llama-2-7b-chat-hf'
69
- elif model_name == 'gemma':
70
- pretty_model_name = 'GEMMA'
71
- model_name = 'google/gemma-7b'
72
- else:
73
- pretty_model_name = model_name
74
-
75
- # A rough heuristic on GPU memory requirements, e.g., we found that
76
- # LLAMA2 (7B parameters) fits on a 85GB GPU.
77
- required_memory = 85 * num_params / 7
78
- kwargs = get_llm_kwargs(required_memory, dtype)
79
-
80
- print(f"Setting up '{pretty_model_name}' with configuration: {kwargs}")
81
- self.tokenizer = AutoTokenizer.from_pretrained(
82
- model_name,
83
- use_fast=False,
84
- )
85
- self.tokenizer.pad_token_id = PAD_TOKEN_ID
86
- self.tokenizer.padding_side = PADDING_SIDE
87
- self.llm = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
88
- self.word_embedding = self.llm.model.get_input_embeddings()
89
-
90
- if 'max_memory' not in kwargs: # Pure CPU:
91
- self.llm_device = torch.device('cpu')
92
- self.autocast_context = nullcontext()
93
- else:
94
- self.llm_device = self.llm.device
95
- self.autocast_context = torch.cuda.amp.autocast(dtype=dtype)
96
-
97
- def _encode_inputs(
98
- self,
99
- question: List[str],
100
- context: Optional[List[str]] = None,
101
- ) -> None:
102
- batch_size = len(question)
103
- questions = self.tokenizer(question, add_special_tokens=False)
104
- if context is not None:
105
- context = self.tokenizer(context, add_special_tokens=False)
106
-
107
- eos_user_tokens = self.tokenizer(EOS_USER, add_special_tokens=False)
108
- bos_token = self.tokenizer(
109
- BOS,
110
- add_special_tokens=False,
111
- return_tensors='pt',
112
- ).input_ids[0].to(self.llm_device)
113
- bos_embeds = self.word_embedding(bos_token)
114
- pad_token = torch.tensor(self.tokenizer.pad_token_id,
115
- device=self.llm_device)
116
- pad_embeds = self.word_embedding(pad_token).unsqueeze(0)
117
- return (batch_size, questions, context, eos_user_tokens, bos_embeds,
118
- pad_embeds)
119
-
120
- def forward(
121
- self,
122
- question: List[str],
123
- answer: List[str],
124
- context: Optional[List[str]] = None,
125
- embedding: Optional[List[Tensor]] = None,
126
- ) -> Tensor:
127
- r"""The forward pass.
128
-
129
- Args:
130
- question (list[str]): The questions/prompts.
131
- answer (list[str]): The answers/labels.
132
- context (list[str], optional): Additional context to give to the
133
- LLM, such as textified knowledge graphs. (default: :obj:`None`)
134
- embedding (list[torch.Tensor], optional): RAG embedding
135
- tensors, *i.e.* the embedded form of :obj:`context`. Either
136
- :obj:`context` or :obj:`rag_embeddings` should be used, not
137
- both. (default: :obj:`None`)
138
- """
139
- if context is not None and embedding is not None:
140
- warnings.warn("Using both 'context' and 'embedding' is a waste of "
141
- "compute and memory")
142
-
143
- (batch_size, question, context, eos_user_tokens, bos_embeds,
144
- pad_embeds) = self._encode_inputs(question, context)
145
-
146
- label = self.tokenizer(answer, add_special_tokens=False)
147
- eos_tokens = self.tokenizer(EOS, add_special_tokens=False)
148
-
149
- batch_inputs_embeds = []
150
- batch_attention_mask = []
151
- batch_label_input_ids = []
152
- for i in range(batch_size):
153
- label_input_ids = label.input_ids[i][:MAX_NEW_TOKENS]
154
- label_input_ids += eos_tokens.input_ids # Add EOS token.
155
-
156
- input_ids: List[int] = []
157
- if context is not None:
158
- input_ids += context.input_ids[i][:MAX_TXT_LEN]
159
- input_ids += question.input_ids[i]
160
- input_ids += eos_user_tokens.input_ids
161
- input_ids += label_input_ids
162
-
163
- inputs_embeds = self.word_embedding(
164
- torch.tensor(input_ids, device=self.llm_device))
165
-
166
- to_cat = [bos_embeds]
167
- if embedding is not None:
168
- to_cat.append(embedding[i])
169
- to_cat.append(inputs_embeds)
170
- inputs_embeds = torch.cat(to_cat, dim=0)
171
-
172
- batch_inputs_embeds.append(inputs_embeds)
173
- batch_attention_mask.append([1] * inputs_embeds.size(0))
174
- label_input_ids = [IGNORE_INDEX] * (
175
- inputs_embeds.size(0) - len(label_input_ids)) + label_input_ids
176
- batch_label_input_ids.append(label_input_ids)
177
-
178
- # Pad input embeddings:
179
- max_length = max([x.size(0) for x in batch_inputs_embeds])
180
- for i in range(batch_size):
181
- pad = max_length - batch_inputs_embeds[i].size(0)
182
- batch_inputs_embeds[i] = torch.cat([
183
- pad_embeds.repeat(pad, 1),
184
- batch_inputs_embeds[i],
185
- ])
186
- batch_attention_mask[i] = [0] * pad + batch_attention_mask[i]
187
- batch_label_input_ids[i] = ([IGNORE_INDEX] * pad +
188
- batch_label_input_ids[i])
189
-
190
- inputs_embeds = torch.stack(batch_inputs_embeds, dim=0)
191
- attention_mask = torch.tensor(batch_attention_mask,
192
- device=self.llm_device)
193
- label_input_ids = torch.tensor(batch_label_input_ids,
194
- device=self.llm_device)
195
-
196
- with self.autocast_context:
197
- outputs = self.llm(
198
- inputs_embeds=inputs_embeds,
199
- attention_mask=attention_mask,
200
- return_dict=True,
201
- labels=label_input_ids,
202
- )
203
- return outputs.loss
204
-
205
- @torch.no_grad()
206
- def inference(
207
- self,
208
- question: List[str],
209
- context: Optional[List[str]] = None,
210
- embedding: Optional[List[Tensor]] = None,
211
- max_tokens: Optional[int] = MAX_NEW_TOKENS,
212
- ) -> List[str]:
213
- r"""The inference pass.
214
-
215
- Args:
216
- question (list[str]): The questions/prompts.
217
- answer (list[str]): The answers/labels.
218
- context (list[str], optional): Additional context to give to the
219
- LLM, such as textified knowledge graphs. (default: :obj:`None`)
220
- embedding (list[torch.Tensor], optional): RAG embedding
221
- tensors, *i.e.* the embedded form of :obj:`context`. Either
222
- :obj:`context` or :obj:`rag_embeddings` should be used, not
223
- both. (default: :obj:`None`)
224
- max_tokens (int, optional): How many tokens for the LLM to
225
- generate. (default: :obj:`32`)
226
- """
227
- if context is not None and embedding is not None:
228
- warnings.warn("Using both 'context' and 'embedding' is a waste of "
229
- "compute and memory")
230
-
231
- (batch_size, question, context, eos_user_tokens, bos_embeds,
232
- pad_embeds) = self._encode_inputs(question, context)
233
-
234
- batch_inputs_embeds = []
235
- batch_attention_mask = []
236
- for i in range(batch_size):
237
- input_ids: List[int] = []
238
- if context is not None:
239
- input_ids = context.input_ids[i][:MAX_TXT_LEN]
240
- input_ids += question.input_ids[i]
241
- input_ids += eos_user_tokens.input_ids
242
-
243
- inputs_embeds = self.word_embedding(
244
- torch.tensor(input_ids, device=self.llm_device))
245
-
246
- to_cat = [bos_embeds]
247
- if embedding is not None:
248
- to_cat.append(embedding[i])
249
- to_cat.append(inputs_embeds)
250
- inputs_embeds = torch.cat(to_cat, dim=0)
251
-
252
- batch_inputs_embeds.append(inputs_embeds)
253
- batch_attention_mask.append([1] * inputs_embeds.size(0))
254
-
255
- # Pad input embeddings:
256
- max_length = max([x.size(0) for x in batch_inputs_embeds])
257
- for i in range(batch_size):
258
- pad = max_length - batch_inputs_embeds[i].size(0)
259
- batch_inputs_embeds[i] = torch.cat([
260
- pad_embeds.repeat(pad, 1),
261
- batch_inputs_embeds[i],
262
- ])
263
- batch_attention_mask[i] = [0] * pad + batch_attention_mask[i]
264
-
265
- inputs_embeds = torch.stack(batch_inputs_embeds, dim=0)
266
- attention_mask = torch.tensor(batch_attention_mask,
267
- device=self.llm_device)
268
-
269
- bos_token = self.tokenizer(
270
- BOS,
271
- add_special_tokens=False,
272
- ).input_ids[0]
273
-
274
- with self.autocast_context:
275
- outputs = self.llm.generate(
276
- inputs_embeds=inputs_embeds,
277
- bos_token_id=bos_token,
278
- max_new_tokens=max_tokens,
279
- attention_mask=attention_mask,
280
- use_cache=True,
281
- )
282
-
283
- return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
@@ -1,94 +0,0 @@
1
- from enum import Enum
2
- from typing import List, Optional, Union
3
-
4
- import torch
5
- import torch.nn.functional as F
6
- from torch import Tensor
7
-
8
-
9
- class PoolingStrategy(Enum):
10
- MEAN = 'mean'
11
- LAST = 'last'
12
- CLS = 'cls'
13
-
14
-
15
- class SentenceTransformer(torch.nn.Module):
16
- def __init__(
17
- self,
18
- model_name: str,
19
- pooling_strategy: Union[PoolingStrategy, str] = 'mean',
20
- ) -> None:
21
- super().__init__()
22
-
23
- self.model_name = model_name
24
- self.pooling_strategy = PoolingStrategy(pooling_strategy)
25
-
26
- from transformers import AutoModel, AutoTokenizer
27
-
28
- self.tokenizer = AutoTokenizer.from_pretrained(model_name)
29
- self.model = AutoModel.from_pretrained(model_name)
30
-
31
- def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
32
- out = self.model(input_ids=input_ids, attention_mask=attention_mask)
33
-
34
- emb = out[0] # First element contains all token embeddings.
35
- if self.pooling_strategy == PoolingStrategy.MEAN:
36
- emb = mean_pooling(emb, attention_mask)
37
- elif self.pooling_strategy == PoolingStrategy.LAST:
38
- emb = last_pooling(emb, attention_mask)
39
- else:
40
- assert self.pooling_strategy == PoolingStrategy.CLS
41
- emb = emb[:, 0, :]
42
-
43
- emb = F.normalize(emb, p=2, dim=1)
44
- return emb
45
-
46
- @property
47
- def device(self) -> torch.device:
48
- return next(iter(self.model.parameters())).device
49
-
50
- @torch.no_grad()
51
- def encode(
52
- self,
53
- text: List[str],
54
- batch_size: Optional[int] = None,
55
- output_device: Optional[torch.device] = None,
56
- ) -> Tensor:
57
- batch_size = len(text) if batch_size is None else batch_size
58
-
59
- embs: List[Tensor] = []
60
- for start in range(0, len(text), batch_size):
61
- token = self.tokenizer(
62
- text[start:start + batch_size],
63
- padding=True,
64
- truncation=True,
65
- return_tensors='pt',
66
- )
67
-
68
- emb = self(
69
- input_ids=token.input_ids.to(self.device),
70
- attention_mask=token.attention_mask.to(self.device),
71
- ).to(output_device or 'cpu')
72
-
73
- embs.append(emb)
74
-
75
- return torch.cat(embs, dim=0) if len(embs) > 1 else embs[0]
76
-
77
- def __repr__(self) -> str:
78
- return f'{self.__class__.__name__}(model_name={self.model_name})'
79
-
80
-
81
- def mean_pooling(emb: Tensor, attention_mask: Tensor) -> Tensor:
82
- mask = attention_mask.unsqueeze(-1).expand(emb.size()).to(emb.dtype)
83
- return (emb * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
84
-
85
-
86
- def last_pooling(emb: Tensor, attention_mask: Tensor) -> Tensor:
87
- # Check whether language model uses left padding,
88
- # which is always used for decoder LLMs
89
- left_padding = attention_mask[:, -1].sum() == attention_mask.size(0)
90
- if left_padding:
91
- return emb[:, -1]
92
-
93
- seq_indices = attention_mask.sum(dim=1) - 1
94
- return emb[torch.arange(emb.size(0), device=emb.device), seq_indices]