pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pyg-nightly might be problematic. Click here for more details.

Files changed (268) hide show
  1. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
  2. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
  3. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +34 -1
  6. torch_geometric/_compile.py +11 -3
  7. torch_geometric/_onnx.py +228 -0
  8. torch_geometric/config_mixin.py +8 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/__init__.py +19 -1
  13. torch_geometric/data/batch.py +2 -2
  14. torch_geometric/data/collate.py +1 -3
  15. torch_geometric/data/data.py +110 -6
  16. torch_geometric/data/database.py +19 -5
  17. torch_geometric/data/dataset.py +14 -9
  18. torch_geometric/data/extract.py +1 -1
  19. torch_geometric/data/feature_store.py +17 -22
  20. torch_geometric/data/graph_store.py +3 -2
  21. torch_geometric/data/hetero_data.py +139 -7
  22. torch_geometric/data/hypergraph_data.py +2 -2
  23. torch_geometric/data/in_memory_dataset.py +2 -2
  24. torch_geometric/data/lightning/datamodule.py +42 -28
  25. torch_geometric/data/storage.py +9 -1
  26. torch_geometric/datasets/__init__.py +20 -1
  27. torch_geometric/datasets/actor.py +7 -9
  28. torch_geometric/datasets/airfrans.py +17 -20
  29. torch_geometric/datasets/airports.py +8 -10
  30. torch_geometric/datasets/amazon.py +8 -11
  31. torch_geometric/datasets/amazon_book.py +8 -9
  32. torch_geometric/datasets/amazon_products.py +7 -9
  33. torch_geometric/datasets/aminer.py +8 -9
  34. torch_geometric/datasets/aqsol.py +10 -13
  35. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  36. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  37. torch_geometric/datasets/ba_shapes.py +5 -6
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/city.py +157 -0
  40. torch_geometric/datasets/dbp15k.py +1 -1
  41. torch_geometric/datasets/gdelt_lite.py +3 -2
  42. torch_geometric/datasets/ged_dataset.py +3 -2
  43. torch_geometric/datasets/git_mol_dataset.py +263 -0
  44. torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
  45. torch_geometric/datasets/hgb_dataset.py +2 -2
  46. torch_geometric/datasets/hm.py +1 -1
  47. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  48. torch_geometric/datasets/linkx_dataset.py +4 -3
  49. torch_geometric/datasets/lrgb.py +3 -5
  50. torch_geometric/datasets/malnet_tiny.py +2 -1
  51. torch_geometric/datasets/md17.py +3 -3
  52. torch_geometric/datasets/medshapenet.py +145 -0
  53. torch_geometric/datasets/mnist_superpixels.py +2 -3
  54. torch_geometric/datasets/modelnet.py +1 -1
  55. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  56. torch_geometric/datasets/molecule_net.py +3 -2
  57. torch_geometric/datasets/neurograph.py +1 -3
  58. torch_geometric/datasets/ogb_mag.py +1 -1
  59. torch_geometric/datasets/opf.py +19 -5
  60. torch_geometric/datasets/pascal_pf.py +1 -1
  61. torch_geometric/datasets/pcqm4m.py +2 -1
  62. torch_geometric/datasets/ppi.py +2 -1
  63. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  64. torch_geometric/datasets/qm7.py +1 -1
  65. torch_geometric/datasets/qm9.py +3 -2
  66. torch_geometric/datasets/shrec2016.py +2 -2
  67. torch_geometric/datasets/snap_dataset.py +8 -4
  68. torch_geometric/datasets/tag_dataset.py +462 -0
  69. torch_geometric/datasets/teeth3ds.py +269 -0
  70. torch_geometric/datasets/web_qsp_dataset.py +342 -0
  71. torch_geometric/datasets/wikics.py +2 -1
  72. torch_geometric/datasets/wikidata.py +2 -1
  73. torch_geometric/deprecation.py +1 -1
  74. torch_geometric/distributed/__init__.py +13 -0
  75. torch_geometric/distributed/dist_loader.py +2 -2
  76. torch_geometric/distributed/local_feature_store.py +3 -2
  77. torch_geometric/distributed/local_graph_store.py +2 -1
  78. torch_geometric/distributed/partition.py +9 -8
  79. torch_geometric/distributed/rpc.py +3 -3
  80. torch_geometric/edge_index.py +35 -22
  81. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  82. torch_geometric/explain/algorithm/base.py +2 -2
  83. torch_geometric/explain/algorithm/captum.py +1 -1
  84. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  85. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  86. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  87. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  88. torch_geometric/explain/explainer.py +2 -2
  89. torch_geometric/explain/explanation.py +89 -5
  90. torch_geometric/explain/metric/faithfulness.py +1 -1
  91. torch_geometric/graphgym/checkpoint.py +2 -1
  92. torch_geometric/graphgym/config.py +3 -2
  93. torch_geometric/graphgym/imports.py +15 -4
  94. torch_geometric/graphgym/logger.py +1 -1
  95. torch_geometric/graphgym/loss.py +1 -1
  96. torch_geometric/graphgym/models/encoder.py +2 -2
  97. torch_geometric/graphgym/models/layer.py +1 -1
  98. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  99. torch_geometric/hash_tensor.py +798 -0
  100. torch_geometric/index.py +16 -7
  101. torch_geometric/inspector.py +6 -2
  102. torch_geometric/io/fs.py +27 -0
  103. torch_geometric/io/tu.py +2 -3
  104. torch_geometric/llm/__init__.py +9 -0
  105. torch_geometric/llm/large_graph_indexer.py +741 -0
  106. torch_geometric/llm/models/__init__.py +23 -0
  107. torch_geometric/llm/models/g_retriever.py +251 -0
  108. torch_geometric/llm/models/git_mol.py +336 -0
  109. torch_geometric/llm/models/glem.py +397 -0
  110. torch_geometric/llm/models/llm.py +470 -0
  111. torch_geometric/llm/models/llm_judge.py +158 -0
  112. torch_geometric/llm/models/molecule_gpt.py +222 -0
  113. torch_geometric/llm/models/protein_mpnn.py +333 -0
  114. torch_geometric/llm/models/sentence_transformer.py +188 -0
  115. torch_geometric/llm/models/txt2kg.py +353 -0
  116. torch_geometric/llm/models/vision_transformer.py +38 -0
  117. torch_geometric/llm/rag_loader.py +154 -0
  118. torch_geometric/llm/utils/__init__.py +10 -0
  119. torch_geometric/llm/utils/backend_utils.py +443 -0
  120. torch_geometric/llm/utils/feature_store.py +169 -0
  121. torch_geometric/llm/utils/graph_store.py +199 -0
  122. torch_geometric/llm/utils/vectorrag.py +125 -0
  123. torch_geometric/loader/cluster.py +6 -5
  124. torch_geometric/loader/graph_saint.py +2 -1
  125. torch_geometric/loader/ibmb_loader.py +4 -4
  126. torch_geometric/loader/link_loader.py +1 -1
  127. torch_geometric/loader/link_neighbor_loader.py +2 -1
  128. torch_geometric/loader/mixin.py +6 -5
  129. torch_geometric/loader/neighbor_loader.py +1 -1
  130. torch_geometric/loader/neighbor_sampler.py +2 -2
  131. torch_geometric/loader/prefetch.py +4 -3
  132. torch_geometric/loader/temporal_dataloader.py +2 -2
  133. torch_geometric/loader/utils.py +10 -10
  134. torch_geometric/metrics/__init__.py +23 -2
  135. torch_geometric/metrics/link_pred.py +755 -85
  136. torch_geometric/nn/__init__.py +1 -0
  137. torch_geometric/nn/aggr/__init__.py +2 -0
  138. torch_geometric/nn/aggr/base.py +1 -1
  139. torch_geometric/nn/aggr/equilibrium.py +1 -1
  140. torch_geometric/nn/aggr/fused.py +1 -1
  141. torch_geometric/nn/aggr/patch_transformer.py +149 -0
  142. torch_geometric/nn/aggr/set_transformer.py +1 -1
  143. torch_geometric/nn/aggr/utils.py +9 -4
  144. torch_geometric/nn/attention/__init__.py +9 -1
  145. torch_geometric/nn/attention/polynormer.py +107 -0
  146. torch_geometric/nn/attention/qformer.py +71 -0
  147. torch_geometric/nn/attention/sgformer.py +99 -0
  148. torch_geometric/nn/conv/__init__.py +2 -0
  149. torch_geometric/nn/conv/appnp.py +1 -1
  150. torch_geometric/nn/conv/collect.jinja +6 -3
  151. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  152. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  153. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  154. torch_geometric/nn/conv/dna_conv.py +1 -1
  155. torch_geometric/nn/conv/eg_conv.py +7 -7
  156. torch_geometric/nn/conv/gat_conv.py +33 -4
  157. torch_geometric/nn/conv/gatv2_conv.py +35 -4
  158. torch_geometric/nn/conv/gen_conv.py +1 -1
  159. torch_geometric/nn/conv/general_conv.py +1 -1
  160. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  161. torch_geometric/nn/conv/hetero_conv.py +3 -2
  162. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  163. torch_geometric/nn/conv/message_passing.py +6 -5
  164. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  165. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  166. torch_geometric/nn/conv/sg_conv.py +1 -1
  167. torch_geometric/nn/conv/spline_conv.py +2 -1
  168. torch_geometric/nn/conv/ssg_conv.py +1 -1
  169. torch_geometric/nn/conv/transformer_conv.py +5 -3
  170. torch_geometric/nn/data_parallel.py +5 -4
  171. torch_geometric/nn/dense/linear.py +5 -24
  172. torch_geometric/nn/encoding.py +17 -3
  173. torch_geometric/nn/fx.py +17 -15
  174. torch_geometric/nn/model_hub.py +5 -16
  175. torch_geometric/nn/models/__init__.py +11 -0
  176. torch_geometric/nn/models/attentive_fp.py +1 -1
  177. torch_geometric/nn/models/attract_repel.py +148 -0
  178. torch_geometric/nn/models/basic_gnn.py +2 -1
  179. torch_geometric/nn/models/captum.py +1 -1
  180. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  181. torch_geometric/nn/models/dimenet.py +2 -2
  182. torch_geometric/nn/models/dimenet_utils.py +4 -2
  183. torch_geometric/nn/models/gpse.py +1083 -0
  184. torch_geometric/nn/models/graph_unet.py +13 -4
  185. torch_geometric/nn/models/lpformer.py +783 -0
  186. torch_geometric/nn/models/metapath2vec.py +1 -1
  187. torch_geometric/nn/models/mlp.py +4 -2
  188. torch_geometric/nn/models/node2vec.py +1 -1
  189. torch_geometric/nn/models/polynormer.py +206 -0
  190. torch_geometric/nn/models/rev_gnn.py +3 -3
  191. torch_geometric/nn/models/schnet.py +2 -1
  192. torch_geometric/nn/models/sgformer.py +219 -0
  193. torch_geometric/nn/models/signed_gcn.py +1 -1
  194. torch_geometric/nn/models/visnet.py +2 -2
  195. torch_geometric/nn/norm/batch_norm.py +17 -7
  196. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  197. torch_geometric/nn/norm/graph_norm.py +9 -4
  198. torch_geometric/nn/norm/instance_norm.py +5 -1
  199. torch_geometric/nn/norm/layer_norm.py +15 -7
  200. torch_geometric/nn/norm/msg_norm.py +8 -2
  201. torch_geometric/nn/pool/__init__.py +15 -9
  202. torch_geometric/nn/pool/cluster_pool.py +144 -0
  203. torch_geometric/nn/pool/connect/base.py +1 -3
  204. torch_geometric/nn/pool/edge_pool.py +1 -1
  205. torch_geometric/nn/pool/knn.py +13 -10
  206. torch_geometric/nn/pool/select/base.py +1 -4
  207. torch_geometric/nn/summary.py +1 -1
  208. torch_geometric/nn/to_hetero_module.py +4 -3
  209. torch_geometric/nn/to_hetero_transformer.py +3 -3
  210. torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
  211. torch_geometric/profile/__init__.py +2 -0
  212. torch_geometric/profile/nvtx.py +66 -0
  213. torch_geometric/profile/profiler.py +18 -9
  214. torch_geometric/profile/utils.py +20 -5
  215. torch_geometric/sampler/__init__.py +2 -1
  216. torch_geometric/sampler/base.py +337 -8
  217. torch_geometric/sampler/hgt_sampler.py +11 -1
  218. torch_geometric/sampler/neighbor_sampler.py +298 -25
  219. torch_geometric/sampler/utils.py +93 -5
  220. torch_geometric/testing/__init__.py +4 -0
  221. torch_geometric/testing/decorators.py +35 -5
  222. torch_geometric/testing/distributed.py +1 -1
  223. torch_geometric/transforms/__init__.py +4 -0
  224. torch_geometric/transforms/add_gpse.py +49 -0
  225. torch_geometric/transforms/add_metapaths.py +10 -8
  226. torch_geometric/transforms/add_positional_encoding.py +2 -2
  227. torch_geometric/transforms/base_transform.py +2 -1
  228. torch_geometric/transforms/delaunay.py +65 -15
  229. torch_geometric/transforms/face_to_edge.py +32 -3
  230. torch_geometric/transforms/gdc.py +8 -9
  231. torch_geometric/transforms/largest_connected_components.py +1 -1
  232. torch_geometric/transforms/mask.py +5 -1
  233. torch_geometric/transforms/node_property_split.py +1 -1
  234. torch_geometric/transforms/normalize_features.py +3 -3
  235. torch_geometric/transforms/pad.py +1 -1
  236. torch_geometric/transforms/random_link_split.py +1 -1
  237. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  238. torch_geometric/transforms/remove_self_loops.py +36 -0
  239. torch_geometric/transforms/rooted_subgraph.py +1 -1
  240. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  241. torch_geometric/transforms/virtual_node.py +2 -1
  242. torch_geometric/typing.py +82 -17
  243. torch_geometric/utils/__init__.py +6 -1
  244. torch_geometric/utils/_lexsort.py +0 -9
  245. torch_geometric/utils/_negative_sampling.py +28 -13
  246. torch_geometric/utils/_normalize_edge_index.py +46 -0
  247. torch_geometric/utils/_scatter.py +126 -164
  248. torch_geometric/utils/_sort_edge_index.py +0 -2
  249. torch_geometric/utils/_spmm.py +16 -14
  250. torch_geometric/utils/_subgraph.py +4 -0
  251. torch_geometric/utils/_tree_decomposition.py +1 -1
  252. torch_geometric/utils/_trim_to_layer.py +2 -2
  253. torch_geometric/utils/augmentation.py +1 -1
  254. torch_geometric/utils/convert.py +17 -10
  255. torch_geometric/utils/cross_entropy.py +34 -13
  256. torch_geometric/utils/embedding.py +91 -2
  257. torch_geometric/utils/geodesic.py +28 -25
  258. torch_geometric/utils/influence.py +279 -0
  259. torch_geometric/utils/map.py +14 -10
  260. torch_geometric/utils/nested.py +1 -1
  261. torch_geometric/utils/smiles.py +3 -3
  262. torch_geometric/utils/sparse.py +32 -24
  263. torch_geometric/visualization/__init__.py +2 -1
  264. torch_geometric/visualization/graph.py +250 -5
  265. torch_geometric/warnings.py +11 -2
  266. torch_geometric/nn/nlp/__init__.py +0 -7
  267. torch_geometric/nn/nlp/llm.py +0 -283
  268. torch_geometric/nn/nlp/sentence_transformer.py +0 -94
@@ -0,0 +1,470 @@
1
+ import warnings
2
+ from contextlib import nullcontext
3
+ from typing import Any, Dict, List, Optional
4
+
5
+ import torch
6
+ from torch import Tensor
7
+
8
+ try:
9
+ from transformers.tokenization_utils_base import BatchEncoding
10
+ except ImportError:
11
+ BatchEncoding = Dict
12
+
13
# Label value ignored by the HuggingFace causal-LM loss.
IGNORE_INDEX = -100
# Maximum number of context tokens kept per sample (legacy prompting).
MAX_TXT_LEN = 512
# Maximum number of label / generated tokens.
MAX_NEW_TOKENS = 128
# Fallback pad token id if the tokenizer does not define one.
PAD_TOKEN_ID = 0
# Fallback padding side if the tokenizer does not define one.
PADDING_SIDE = 'left'

# legacy constants - used for Llama 2 style prompting
# (`<s>[INST] prompt [/INST] answer </s>`).
BOS = '<s>[INST]'
EOS_USER = '[/INST]'
# Llama 2's end-of-sequence token is `</s>`; the previous value `'[/s]'`
# is not a special token and was tokenized into literal characters, so the
# model was never trained to emit a real EOS after its answer.
EOS = '</s>'
23
+
24
+
25
def get_llm_kwargs(
    required_memory: float,
    dtype: torch.dtype = torch.bfloat16,
) -> Dict[str, Any]:
    r"""Builds keyword arguments for :meth:`from_pretrained` of an LLM.

    Visible CUDA devices are taken in order until their free memory
    (rounded down to whole GiB) adds up to :obj:`required_memory`, so the
    fewest GPUs are used. If all visible GPUs together do not have enough
    free memory, no :obj:`max_memory` entry is produced, which signals
    pure-CPU usage to the caller.

    Args:
        required_memory (float): Estimated memory requirement of the model,
            in GiB.
        dtype (torch.dtype, optional): The data type to load the model in.
            (default: :obj:`torch.bfloat16`)

    Returns:
        Dict[str, Any]: Always contains :obj:`revision`,
        :obj:`low_cpu_mem_usage`, :obj:`device_map` and
        :obj:`torch_dtype`; contains :obj:`max_memory` (a mapping from GPU
        index to a ``'<n>GiB'`` string) only if enough GPU memory was found.
    """
    torch.cuda.empty_cache()

    gpu_memory: List[int] = []
    for i in range(torch.cuda.device_count()):
        gpu_memory.append(torch.cuda.mem_get_info(i)[0] // 1024**3)
        # Use the minimum number of GPUs to fit the LLM on.
        if sum(gpu_memory) >= required_memory:
            break

    if sum(gpu_memory) < required_memory:
        gpu_memory = []  # If not enough VRAM, use pure CPU.

    kwargs: Dict[str, Any] = dict(revision='main')
    if len(gpu_memory) > 0:
        kwargs['max_memory'] = {
            i: f'{memory}GiB'
            for i, memory in enumerate(gpu_memory)
        }
    kwargs['low_cpu_mem_usage'] = True
    kwargs['device_map'] = 'auto'
    kwargs['torch_dtype'] = dtype

    return kwargs
49
+
50
+
51
class LLM(torch.nn.Module):
    r"""A wrapper around a Large Language Model (LLM) from HuggingFace.

    Args:
        model_name (str): The HuggingFace model name.
        num_params (float, optional): How many parameters the HuggingFace
            model has, in billions. This is used to automatically allocate
            the number of GPUs needed (using a rough heuristic), given the
            available memory of your GPUs. If not specified, the number of
            parameters is determined using the :obj:`huggingface_hub`
            module. Ignored if :obj:`n_gpus` is set. (default: :obj:`None`)
        n_gpus (int, optional): Number of GPUs to use. Designed for advanced
            users who want to set this manually and override the automatic
            set-up mechanism. (default: :obj:`None`)
        dtype (torch.dtype, optional): The data type to use for the LLM.
            (default: :obj:`torch.bfloat16`)
        sys_prompt (str, optional): A system prompt to use for the LLM. If it
            is not given, or the tokenizer has no chat template, Llama 2
            style prompting is used instead. (default: :obj:`None`)
    """
    def __init__(
        self,
        model_name: str,
        num_params: Optional[float] = None,
        n_gpus: Optional[int] = None,
        dtype: Optional[torch.dtype] = torch.bfloat16,
        sys_prompt: Optional[str] = None,
    ) -> None:
        super().__init__()

        self.model_name = model_name

        from transformers import AutoModelForCausalLM, AutoTokenizer
        if n_gpus is None:
            if num_params is None:
                from huggingface_hub import get_safetensors_metadata
                safetensors_metadata = get_safetensors_metadata(model_name)
                param_count = safetensors_metadata.parameter_count
                # Per the huggingface_hub docs, `parameter_count` maps each
                # stored dtype to its number of parameters, so the model
                # size is the sum over all dtypes. True division keeps
                # sub-billion sizes (e.g. 0.5B) from collapsing to 0.
                num_params = sum(param_count.values()) / 10**9

            # A rough heuristic on GPU memory requirements, e.g., we found that
            # LLAMA3 (8B parameters) fits on a 96GB GPU.
            required_memory = 96.0 * num_params / 8.0
            kwargs = get_llm_kwargs(required_memory, dtype)
        else:
            # Manual override: use the first `n_gpus` devices and all of
            # their currently free memory (rounded down to whole GiB).
            gpu_memory: List[int] = [
                torch.cuda.mem_get_info(i)[0] // 1024**3
                for i in range(n_gpus)
            ]
            kwargs = dict(revision='main')
            kwargs['max_memory'] = {
                i: f'{memory}GiB'
                for i, memory in enumerate(gpu_memory)
            }
            kwargs['low_cpu_mem_usage'] = True
            kwargs['device_map'] = 'auto'
            kwargs['torch_dtype'] = dtype

        print(f"Setting up '{model_name}' with configuration: {kwargs}")
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            use_fast=False,
        )
        if self.tokenizer.chat_template and self.tokenizer.bos_token is None:
            # The tokenizer has a chat template but no BOS token: use the
            # first token emitted by the chat template as BOS instead.
            dummy_convo = [
                {
                    "role": "system",
                    "content": "dummy"
                },
                {
                    "role": "user",
                    "content": "convo"
                },
            ]
            text = self.tokenizer.apply_chat_template(
                dummy_convo,
                tokenize=True,
            )
            self.tokenizer.bos_token = self.tokenizer.decode(text[0])
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = PAD_TOKEN_ID
        if self.tokenizer.padding_side is None:
            self.tokenizer.padding_side = PADDING_SIDE
        self.llm = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
        # `get_input_embeddings()` is available on every HuggingFace model,
        # so this does not assume the backbone is exposed as `.model`.
        self.word_embedding = self.llm.get_input_embeddings()
        self.sys_prompt = sys_prompt if sys_prompt is not None else ""
        if 'max_memory' not in kwargs:  # Pure CPU:
            warnings.warn(
                "LLM is being used on CPU, which may be slow. This decision "
                "was made by a rough hueristic that assumes your GPU set up "
                "does not have enough GPU RAM. This is done to avoid GPU OOM "
                "errors. If you think this is a mistake, please initialize "
                "your LLM with the n_gpus param to dictate how many gpus to "
                "use for the LLM.", stacklevel=2)
            self.device = torch.device('cpu')
            self.autocast_context = nullcontext()
        else:
            self.device = self.llm.device
            if dtype == torch.float32:
                self.autocast_context = nullcontext()
            else:
                self.autocast_context = torch.amp.autocast('cuda', dtype=dtype)

    # legacy function - used for Llama 2 style prompting
    def _encode_inputs(
        self,
        question: List[str],
        context: Optional[List[str]] = None,
    ) -> tuple:
        """Tokenizes the batch for Llama 2 style prompting.

        Returns:
            tuple: ``(batch_size, questions, context, eos_user_tokens,
            bos_embeds, pad_embeds)`` where ``questions``/``context`` are
            tokenizer outputs (``context`` stays :obj:`None` if not given),
            ``bos_embeds`` embeds :obj:`BOS` and ``pad_embeds`` embeds the
            pad token with a leading dimension of size 1.
        """
        batch_size = len(question)
        questions = self.tokenizer(question, add_special_tokens=False)
        if context is not None:
            context = self.tokenizer(context, add_special_tokens=False)

        eos_user_tokens = self.tokenizer(EOS_USER, add_special_tokens=False)
        bos_token = self.tokenizer(
            BOS,
            add_special_tokens=False,
            return_tensors='pt',
        ).input_ids[0].to(self.device)
        bos_embeds = self.word_embedding(bos_token)
        pad_token = torch.tensor(self.tokenizer.pad_token_id,
                                 device=self.device)
        pad_embeds = self.word_embedding(pad_token).unsqueeze(0)
        return (batch_size, questions, context, eos_user_tokens, bos_embeds,
                pad_embeds)

    def _label_input_ids(
        self,
        i: int,
        label: BatchEncoding,
        eos_tokens: BatchEncoding,
    ) -> List[int]:
        """Returns the token ids of the ``i``-th label, truncated to
        :obj:`MAX_NEW_TOKENS`, followed by the end-of-sequence token ids.
        """
        label_input_ids = label.input_ids[i][:MAX_NEW_TOKENS]
        label_input_ids = label_input_ids + eos_tokens.input_ids
        return label_input_ids

    # legacy function - used for Llama 2 style prompting
    def _input_ids(
        self,
        i: int,
        context: BatchEncoding,
        question: BatchEncoding,
        eos_user_tokens: BatchEncoding,
    ) -> List[int]:
        """Returns the prompt token ids of sample ``i``: the context
        (truncated to :obj:`MAX_TXT_LEN`, if any), the question, and the
        end-of-user-turn tokens.
        """
        input_ids: List[int] = []
        if context is not None:
            input_ids += context.input_ids[i][:MAX_TXT_LEN]
        input_ids += question.input_ids[i]
        input_ids += eos_user_tokens.input_ids
        return input_ids

    def _inputs_embeds(
        self,
        i: int,
        input_ids: List[int],
        bos_embeds: Tensor,
        embedding: Optional[List[Tensor]] = None,
    ) -> Tensor:
        """Embeds ``input_ids`` and concatenates, along dim 0, the BOS
        embeddings, the optional ``embedding[i]`` and the token embeddings.
        """
        inputs_embeds = self.word_embedding(
            torch.tensor(input_ids, device=self.device))

        to_cat = [bos_embeds]
        if embedding is not None and embedding[i] is not None:
            to_cat.append(embedding[i])
        to_cat.append(inputs_embeds)
        return torch.cat(to_cat, dim=0).to(self.device)

    def _append_embeds(
        self,
        inputs_embeds: Tensor,
        batch_inputs_embeds: List[Tensor],
        batch_attention_mask: List[List[int]],
        label_input_ids: Optional[List[int]] = None,
        batch_label_input_ids: Optional[List[List[int]]] = None,
    ) -> tuple:
        """Appends one sample to the (mutated in place) batch lists.

        Labels are left-padded with :obj:`IGNORE_INDEX` so that they align
        with the end of ``inputs_embeds``; only the answer tokens
        contribute to the loss. If ``label_input_ids`` is given,
        ``batch_label_input_ids`` must be a list.
        """
        batch_inputs_embeds.append(inputs_embeds)
        batch_attention_mask.append([1] * inputs_embeds.size(0))
        if label_input_ids is not None:
            pad = inputs_embeds.size(0) - len(label_input_ids)
            label_input_ids = [IGNORE_INDEX] * pad + label_input_ids
            batch_label_input_ids.append(label_input_ids)
        return batch_inputs_embeds, batch_attention_mask, batch_label_input_ids

    def _pad_embeds(
        self,
        pad_embeds: Tensor,
        batch_inputs_embeds: List[Tensor],
        batch_attention_mask: List[List[int]],
        batch_label_input_ids: Optional[List[List[int]]] = None,
    ) -> tuple:
        """Left-pads all samples to the longest sequence and stacks them.

        Returns:
            tuple: ``(inputs_embeds, attention_mask, label_input_ids)``,
            where ``label_input_ids`` is :obj:`None` if no labels were given.
        """
        max_length = max(x.size(0) for x in batch_inputs_embeds)
        batch_size = len(batch_inputs_embeds)
        for i in range(batch_size):
            pad = max_length - batch_inputs_embeds[i].size(0)
            batch_inputs_embeds[i] = torch.cat([
                pad_embeds.repeat(pad, 1),
                batch_inputs_embeds[i],
            ])
            batch_attention_mask[i] = [0] * pad + batch_attention_mask[i]
            if batch_label_input_ids is not None:
                tmp = [IGNORE_INDEX] * pad + batch_label_input_ids[i]
                batch_label_input_ids[i] = tmp
        inputs_embeds = torch.stack(batch_inputs_embeds, dim=0)
        attention_mask = torch.tensor(batch_attention_mask, device=self.device)
        label_input_ids = None
        if batch_label_input_ids is not None:
            label_input_ids = torch.tensor(batch_label_input_ids,
                                           device=self.device)
        return inputs_embeds, attention_mask, label_input_ids

    # legacy function - used for Llama 2 style prompting
    def _get_embeds_old(
        self,
        question: List[str],
        context: Optional[List[str]] = None,
        embedding: Optional[List[Tensor]] = None,
        answer: Optional[List[str]] = None,
    ) -> tuple:
        """Builds padded input embeddings using Llama 2 style prompting
        (``BOS context question EOS_USER [answer EOS]``).
        """
        (batch_size, question, context, eos_user_tokens, bos_embeds,
         pad_embeds) = self._encode_inputs(question, context)

        batch_label_input_ids = None
        if answer is not None:
            label = self.tokenizer(answer, add_special_tokens=False)
            eos_tokens = self.tokenizer(EOS, add_special_tokens=False)
            batch_label_input_ids = []

        batch_inputs_embeds = []
        batch_attention_mask = []
        for i in range(batch_size):
            input_ids = self._input_ids(i, context, question, eos_user_tokens)
            if answer is not None:
                label_input_ids = self._label_input_ids(i, label, eos_tokens)
                input_ids += label_input_ids
            else:
                label_input_ids = None

            inputs_embeds = self._inputs_embeds(i, input_ids, bos_embeds,
                                                embedding)

            (
                batch_inputs_embeds,
                batch_attention_mask,
                batch_label_input_ids,
            ) = self._append_embeds(
                inputs_embeds,
                batch_inputs_embeds,
                batch_attention_mask,
                label_input_ids,
                batch_label_input_ids,
            )

        inputs_embeds, attention_mask, label_input_ids = self._pad_embeds(
            pad_embeds, batch_inputs_embeds, batch_attention_mask,
            batch_label_input_ids)

        return inputs_embeds, attention_mask, label_input_ids

    def _get_embeds(
        self,
        question: List[str],
        context: Optional[List[str]] = None,
        embedding: Optional[List[Tensor]] = None,
        answer: Optional[List[str]] = None,
    ) -> tuple:
        """Builds padded input embeddings using the tokenizer's chat
        template, falling back to Llama 2 style prompting when there is no
        chat template or no system prompt.

        Returns:
            tuple: ``(inputs_embeds, attention_mask, label_input_ids)``.
        """
        if not self.tokenizer.chat_template or not self.sys_prompt:
            warnings.warn(
                f"HuggingFace model {self.model_name} is not using a "
                "chat template, using Llama 2 style prompting. Please "
                "consider using a more recent model and initialize the "
                "LLM with `sys_prompt`.", stacklevel=2)
            return self._get_embeds_old(question, context, embedding, answer)
        batch_label_input_ids = None
        if answer is not None:
            label = self.tokenizer(answer, add_special_tokens=False)
            eos_tokens = self.tokenizer(self.tokenizer.eos_token,
                                        add_special_tokens=False)
            batch_label_input_ids = []

        # The BOS and padding embeddings are identical for every sample, so
        # compute them once instead of once per loop iteration.
        bos_token_str = self.tokenizer.bos_token
        bos_token = self.tokenizer(
            bos_token_str,
            add_special_tokens=False,
            return_tensors='pt',
        ).input_ids[0].to(self.device)
        bos_embeds = self.word_embedding(bos_token)
        pad_token = torch.tensor(self.tokenizer.pad_token_id,
                                 device=self.device)
        pad_embeds = self.word_embedding(pad_token).unsqueeze(0)

        batch_inputs_embeds = []
        batch_attention_mask = []
        for i in range(len(question)):
            # Join context and question with a single separator (the old
            # code produced "ctx -  - question" and " - question").
            if context:
                content = f"{context[i]} - {question[i]}"
            else:
                content = question[i]
            messages = [
                {
                    "role": "system",
                    "content": self.sys_prompt
                },
                {
                    "role": "user",
                    "content": content
                },
            ]
            text = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=True,
            )
            # BOS is prepended as an embedding below, so drop the BOS string
            # rendered by the template -- but only if it is really there, to
            # avoid cutting off genuine prompt text.
            if text.startswith(bos_token_str):
                text = text[len(bos_token_str):]
            input_ids = self.tokenizer(text,
                                       add_special_tokens=False).input_ids
            if answer is not None:
                label_input_ids = self._label_input_ids(i, label, eos_tokens)
                input_ids += label_input_ids
            else:
                label_input_ids = None

            inputs_embeds = self._inputs_embeds(i, input_ids, bos_embeds,
                                                embedding)

            (
                batch_inputs_embeds,
                batch_attention_mask,
                batch_label_input_ids,
            ) = self._append_embeds(
                inputs_embeds,
                batch_inputs_embeds,
                batch_attention_mask,
                label_input_ids,
                batch_label_input_ids,
            )

        inputs_embeds, attention_mask, label_input_ids = self._pad_embeds(
            pad_embeds, batch_inputs_embeds, batch_attention_mask,
            batch_label_input_ids)

        return inputs_embeds, attention_mask, label_input_ids

    def forward(
        self,
        question: List[str],
        answer: List[str],
        context: Optional[List[str]] = None,
        embedding: Optional[List[Tensor]] = None,
    ) -> Tensor:
        r"""The forward pass.

        Args:
            question (list[str]): The questions/prompts.
            answer (list[str]): The answers/labels.
            context (list[str], optional): Additional context to give to the
                LLM, such as textified knowledge graphs. (default: :obj:`None`)
            embedding (list[torch.Tensor], optional): RAG embedding
                tensors, *i.e.* the embedded form of :obj:`context`. Either
                :obj:`context` or :obj:`embedding` should be used, not
                both. (default: :obj:`None`)

        Returns:
            torch.Tensor: The language-modeling loss on the answer tokens.
        """
        inputs_embeds, attention_mask, label_input_ids = self._get_embeds(
            question, context, embedding, answer)

        with self.autocast_context:
            outputs = self.llm(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                return_dict=True,
                labels=label_input_ids,
            )
        return outputs.loss

    @torch.no_grad()
    def inference(
        self,
        question: List[str],
        context: Optional[List[str]] = None,
        embedding: Optional[List[Tensor]] = None,
        max_tokens: Optional[int] = MAX_NEW_TOKENS,
    ) -> List[str]:
        r"""The inference pass.

        Args:
            question (list[str]): The questions/prompts.
            context (list[str], optional): Additional context to give to the
                LLM, such as textified knowledge graphs. (default: :obj:`None`)
            embedding (list[torch.Tensor], optional): RAG embedding
                tensors, *i.e.* the embedded form of :obj:`context`. Either
                :obj:`context` or :obj:`embedding` should be used, not
                both. (default: :obj:`None`)
            max_tokens (int, optional): How many new tokens the LLM may
                generate. (default: :obj:`128`)

        Returns:
            list[str]: The decoded generations, with special tokens removed.
        """
        inputs_embeds, attention_mask, _ = self._get_embeds(
            question, context, embedding)

        with self.autocast_context:
            outputs = self.llm.generate(
                inputs_embeds=inputs_embeds,
                bos_token_id=self.tokenizer.bos_token_id,
                max_new_tokens=max_tokens,
                attention_mask=attention_mask,
                pad_token_id=self.tokenizer.eos_token_id,
                use_cache=True,
            )

        return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self.model_name})'
@@ -0,0 +1,158 @@
1
+ from math import isnan
2
+ from typing import Optional
3
+
4
+ from torch_geometric.llm.models.txt2kg import \
5
+ _chunk_to_triples_str_cloud as call_NIM
6
+
7
# Credit for original "Marlin Accuracy" system goes to:
# Gilberto Titericz (NVIDIA)
# This work is an adaptation of his for PyG

# Judge prompt #1: ``model_pred`` is presented as the User Answer and
# ``correct_answer`` as the Reference Answer. Placeholders are filled via
# ``str.format``.
SYSTEM_PROMPT_1 = (
    "Instruction: You are a world class state of the art "
    "assistant for rating "
    "a User Answer given a Question. The Question is completely"
    " answered by the Reference Answer.\n"
    "Say 4, if User Answer is full contained and equivalent to"
    " Reference Answer"
    " in all terms, topics, numbers, metrics, dates and units.\n"
    "Say 2, if User Answer is partially contained and almost "
    "equivalent to Reference Answer"
    " in all terms, topics, numbers, metrics, dates and units.\n"
    "Say 0, if User Answer is not contained in Reference Answer"
    " or not accurate in all terms, topics,"
    " numbers, metrics, dates and units or the User Answer do not"
    " answer the question.\n"
    "Do not explain or justify your rating. Your rating must be "
    "only 4, 2 or 0 according to the instructions above.\n"
    "### Question: \"{question}\"\n"
    "### User Answer: \"{model_pred}\"\n"
    "### Reference Answer: \"{correct_answer}\"\n"
    "The rating is:\n")

# Judge prompt #2: deliberately permutes the roles of the two answers --
# ``model_pred`` fills the Reference Answer slot and ``correct_answer`` the
# User Answer slot -- so the two ratings can be averaged.
SYSTEM_PROMPT_2 = (
    "I will rate the User Answer in comparison to the Reference "
    "Answer for a given Question.\n"
    "A rating of 4 indicates that the User Answer is entirely "
    "consistent with the Reference Answer, covering all aspects,"
    " topics, numbers, metrics, dates, and units.\n"
    "A rating of 2 signifies that the User Answer is mostly "
    "aligned with the Reference Answer, with minor discrepancies"
    " in some areas.\n"
    "A rating of 0 means that the User Answer is either "
    "inaccurate, incomplete, or unrelated to the Reference "
    "Answer, or it fails to address the Question.\n"
    "I will provide the rating without any explanation or "
    "justification, adhering to the following scale: "
    "0 (no match), 2 (partial match), 4 (exact match).\n"
    "Do not explain or justify my rating. My rating must"
    " be only 4, 2 or 0 only.\n\n"
    "Question: \"{question}\"\n\n"
    "Reference Answer: \"{model_pred}\"\n\n"
    "User Answer: \"{correct_answer}\"\n\n"
    "Rating: ")
49
+
50
+
51
# TODO: add support for Local LM
# TODO: add multiproc support like txt2kg
class LLMJudge:
    """Uses NIMs to score a triple of (question, model_pred, correct_answer).

    Each triple is rated twice: once with :obj:`SYSTEM_PROMPT_1`, and once
    with :obj:`SYSTEM_PROMPT_2`, which swaps the roles of the prediction and
    the reference answer. The two ratings are then averaged.
    This whole class is an adaptation of Gilberto's work for PyG.

    Args:
        NVIDIA_NIM_MODEL (str, optional): The name of the NVIDIA NIM model
            to use.
            (default: :obj:`"nvidia/llama-3.1-nemotron-70b-instruct"`)
        NVIDIA_API_KEY (str, optional): The API key for accessing NVIDIA's
            NIM models. (default: :obj:`""`)
        ENDPOINT_URL (str, optional): The URL hosting your model, in case
            you are not using the public NIM.
            (default: :obj:`"https://integrate.api.nvidia.com/v1"`)
    """
    def __init__(
        self,
        NVIDIA_NIM_MODEL: Optional[
            str] = "nvidia/llama-3.1-nemotron-70b-instruct",
        NVIDIA_API_KEY: Optional[str] = "",
        ENDPOINT_URL: Optional[str] = "https://integrate.api.nvidia.com/v1",
    ) -> None:
        self.NVIDIA_API_KEY = NVIDIA_API_KEY
        self.NIM_MODEL = NVIDIA_NIM_MODEL
        self.ENDPOINT_URL = ENDPOINT_URL

    def _process_score(self, response: str) -> float:
        """Map a raw judge response to a score in ``[0, 1]``, or ``nan``.

        Digits are checked from highest to lowest. The first of 4, 3, 2, 1, 0
        that appears anywhere in :obj:`response` is divided by 4. 3 and 1 are
        accepted even though the prompts only ask for 0, 2 or 4, because
        LLMs don't always follow instructions. Credit to Gilberto.

        Returns:
            float: The score, or ``nan`` if no digit 0-4 is present.
        """
        for i in [4, 3, 2, 1, 0]:
            if str(i) in response:
                return i / 4
        return float("nan")

    def _average_scores(self, score0: float, score1: float) -> float:
        """Take the average of score0 and score1.

        Sometimes the LLM fails to respond or has no score in the response
        (a ``nan`` score). In those cases the failed score is discarded.
        Credit to Gilberto.

        Args:
            score0 (float): judge accuracy score.
            score1 (float): judge accuracy score by permuting agent answer
                and ground truth.

        Returns:
            float: The average of both scores if both are valid, the single
            valid score if only one is, and ``nan`` if neither is.
        """
        # Filter explicitly: ``max(nan, x)`` returns ``nan``, so max() cannot
        # be used to drop a failed score.
        valid = [s for s in (score0, score1) if not isnan(s)]
        if not valid:
            return float("nan")
        return sum(valid) / len(valid)

    def _query_score(self, prompt: str, num_retries: int) -> float:
        """Query the NIM with :obj:`prompt` until a score is parsed.

        Makes up to :obj:`num_retries` attempts and returns the first
        non-``nan`` score, or ``nan`` if every attempt fails.
        :obj:`ImportError` is re-raised because retrying cannot fix a missing
        dependency. Any other :obj:`Exception` counts as a failed attempt.
        Only :obj:`Exception` is caught, so ``KeyboardInterrupt`` and
        ``SystemExit`` still propagate.
        """
        for _retry in range(num_retries):
            try:
                score = self._process_score(
                    call_NIM(prompt, self.NVIDIA_API_KEY, self.NIM_MODEL,
                             self.ENDPOINT_URL, post_text=""))
            except ImportError:
                raise
            except Exception:
                # Best-effort: network/API/parsing failures trigger a retry.
                continue
            if not isnan(score):
                return score
        return float("nan")

    def score(
        self,
        question: str,
        model_pred: str,
        correct_answer: str,
    ) -> float:
        """Args:
            question (str): The original question asked to the model.
            model_pred (str): The prediction made by the model.
            correct_answer (str): The actual correct answer to the question.

        Returns:
            score (float): score of 0-1, may be nan due to LLM judge failure.
                Evals should skip nan's when aggregating score.
        """
        prompt1 = SYSTEM_PROMPT_1.format(question=question,
                                         model_pred=model_pred,
                                         correct_answer=correct_answer)
        prompt2 = SYSTEM_PROMPT_2.format(question=question,
                                         model_pred=model_pred,
                                         correct_answer=correct_answer)
        # NOTE(review): the retry budgets differ (200 vs. 20) as in the
        # original implementation; confirm whether this asymmetry is intended.
        score1 = self._query_score(prompt1, num_retries=200)
        score2 = self._query_score(prompt2, num_retries=20)

        return self._average_scores(score1, score2)