pyg-nightly 2.7.0.dev20250905__py3-none-any.whl → 2.7.0.dev20250906__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in the public registry.
- {pyg_nightly-2.7.0.dev20250905.dist-info → pyg_nightly-2.7.0.dev20250906.dist-info}/METADATA +2 -1
- {pyg_nightly-2.7.0.dev20250905.dist-info → pyg_nightly-2.7.0.dev20250906.dist-info}/RECORD +32 -25
- torch_geometric/__init__.py +1 -1
- torch_geometric/data/__init__.py +0 -5
- torch_geometric/data/lightning/datamodule.py +2 -2
- torch_geometric/datasets/molecule_gpt_dataset.py +1 -1
- torch_geometric/datasets/web_qsp_dataset.py +262 -210
- torch_geometric/graphgym/imports.py +2 -2
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/{data → llm}/large_graph_indexer.py +124 -61
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/{nn → llm}/models/g_retriever.py +68 -49
- torch_geometric/{nn → llm}/models/git_mol.py +1 -1
- torch_geometric/{nn/nlp → llm/models}/llm.py +167 -33
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/{nn → llm}/models/molecule_gpt.py +1 -1
- torch_geometric/{nn/nlp → llm/models}/sentence_transformer.py +42 -8
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/backend_utils.py +442 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +124 -0
- torch_geometric/loader/__init__.py +0 -4
- torch_geometric/nn/__init__.py +0 -1
- torch_geometric/nn/models/__init__.py +0 -10
- torch_geometric/nn/models/sgformer.py +2 -0
- torch_geometric/loader/rag_loader.py +0 -107
- torch_geometric/nn/nlp/__init__.py +0 -9
- {pyg_nightly-2.7.0.dev20250905.dist-info → pyg_nightly-2.7.0.dev20250906.dist-info}/WHEEL +0 -0
- {pyg_nightly-2.7.0.dev20250905.dist-info → pyg_nightly-2.7.0.dev20250906.dist-info}/licenses/LICENSE +0 -0
- /torch_geometric/{nn → llm}/models/glem.py +0 -0
- /torch_geometric/{nn → llm}/models/protein_mpnn.py +0 -0
- /torch_geometric/{nn/nlp → llm/models}/vision_transformer.py +0 -0
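The file moves above relocate the NLP/LLM helpers from `torch_geometric.nn` / `torch_geometric.nn.nlp` into the new `torch_geometric.llm` package. For downstream code this amounts to an import-path update roughly along the following lines; this is a sketch only, and whether the old paths keep re-exporting the names for backward compatibility is not shown in this diff.

# Sketch of the import-path change implied by the file moves above.
# Old paths (2.7.0.dev20250905 and earlier):
from torch_geometric.nn.nlp import LLM, SentenceTransformer
from torch_geometric.nn.models import GRetriever

# New paths (2.7.0.dev20250906), assuming torch_geometric.llm.models
# re-exports these classes:
from torch_geometric.llm.models import LLM, SentenceTransformer, GRetriever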
@@ -10,15 +10,17 @@ try:
 except ImportError:
     BatchEncoding = Dict
 
-BOS = '<s>[INST]'
-EOS_USER = '[/INST]'
-EOS = '[/s]'
 IGNORE_INDEX = -100
 MAX_TXT_LEN = 512
-MAX_NEW_TOKENS =
+MAX_NEW_TOKENS = 128
 PAD_TOKEN_ID = 0
 PADDING_SIDE = 'left'
 
+# legacy constants - used for Llama 2 style prompting
+BOS = '<s>[INST]'
+EOS_USER = '[/INST]'
+EOS = '[/s]'
+
 
 def get_llm_kwargs(required_memory: int, dtype=torch.dtype) -> Dict[str, Any]:
     torch.cuda.empty_cache()
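For reference, the legacy constants above encode Llama 2's instruction markers. Roughly, the old prompting path brackets the user text and the answer with them; the following is a purely illustrative string-level sketch (the actual code concatenates token embeddings rather than strings, and the question/answer values are hypothetical):

# Illustrative only: the prompt layout implied by the legacy constants.
question = "What is PyG?"               # hypothetical user question
answer = "A graph learning library."    # hypothetical target answer
legacy_prompt = f"{BOS} {question} {EOS_USER} {answer} {EOS}"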
@@ -50,49 +52,89 @@ class LLM(torch.nn.Module):
     r"""A wrapper around a Large Language Model (LLM) from HuggingFace.
 
     Args:
-        model_name (str): The HuggingFace model name
-
-        num_params (int, optional): An integer representing how many parameters
+        model_name (str): The HuggingFace model name
+        num_params (float, optional): An integer representing how many params
             the HuggingFace model has, in billions. This is used to
-            automatically allocate the correct number of GPUs needed
-            available GPU memory of your GPUs.
-            parameters is determined using the
+            automatically allocate the correct number of GPUs needed (using a
+            rough heuristic), given the available GPU memory of your GPUs. If
+            not specified, the number of parameters is determined using the
+            `huggingface_hub` module.
+        n_gpus (int, optional): Number of GPUs to use. Designed for advanced
+            users to select how many GPU's they want to set this manually and
+            override the automatic set up mechanism.
         dtype (torch.dtype, optional): The data type to use for the LLM.
             (default :obj: `torch.bfloat16`)
+        sys_prompt (str, optional): A system prompt to use for the LLM.
+            (default: :obj: `None`)
     """
     def __init__(
         self,
         model_name: str,
-        num_params: Optional[
+        num_params: Optional[float] = None,
+        n_gpus: Optional[int] = None,
         dtype: Optional[torch.dtype] = torch.bfloat16,
+        sys_prompt: Optional[str] = None,
     ) -> None:
         super().__init__()
 
         self.model_name = model_name
 
         from transformers import AutoModelForCausalLM, AutoTokenizer
-   [original lines 74-84 removed; content not rendered in this diff view]
+        if n_gpus is None:
+            if num_params is None:
+                from huggingface_hub import get_safetensors_metadata
+                safetensors_metadata = get_safetensors_metadata(model_name)
+                param_count = safetensors_metadata.parameter_count
+                num_params = float(list(param_count.values())[0] // 10**9)
+
+            # A rough heuristic on GPU memory requirements, e.g., we found that
+            # LLAMA2 (7B parameters) fits on a 85GB GPU.
+            required_memory = 85 * num_params / 7
+            kwargs = get_llm_kwargs(required_memory, dtype)
+        else:
+            gpu_memory: List[int] = []
+            for i in range(n_gpus):
+                gpu_memory.append(torch.cuda.mem_get_info(i)[0] // 1024**3)
+            kwargs = dict(revision='main')
+            kwargs['max_memory'] = {
+                i: f'{memory}GiB'
+                for i, memory in enumerate(gpu_memory)
+            }
+            kwargs['low_cpu_mem_usage'] = True
+            kwargs['device_map'] = 'auto'
+            kwargs['torch_dtype'] = dtype
 
         print(f"Setting up '{model_name}' with configuration: {kwargs}")
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_name,
             use_fast=False,
         )
-        self.tokenizer.
-
+        if self.tokenizer.chat_template and self.tokenizer.bos_token is None:
+            dummy_convo = [
+                {
+                    "role": "system",
+                    "content": "dummy"
+                },
+                {
+                    "role": "user",
+                    "content": "convo"
+                },
+            ]
+            text = self.tokenizer.apply_chat_template(
+                dummy_convo,
+                tokenize=True,
+            )
+            self.tokenizer.bos_token = self.tokenizer.decode(text[0])
+        if self.tokenizer.pad_token_id is None:
+            self.tokenizer.pad_token_id = PAD_TOKEN_ID
+        if self.tokenizer.padding_side is None:
+            self.tokenizer.padding_side = PADDING_SIDE
         self.llm = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
         self.word_embedding = self.llm.model.get_input_embeddings()
-
+        if sys_prompt is not None:
+            self.sys_prompt = sys_prompt
+        else:
+            self.sys_prompt = ""
        if 'max_memory' not in kwargs:  # Pure CPU:
            warnings.warn("LLM is being used on CPU, which may be slow",
                          stacklevel=2)
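Given the new signature, constructing the wrapper with the added arguments would look roughly as follows. The import path follows the module move in this diff, and the model name, GPU count, and system prompt are placeholders rather than values taken from the diff.

# Hypothetical usage of the updated constructor; the checkpoint name,
# n_gpus, and sys_prompt are placeholders.
import torch

from torch_geometric.llm.models import LLM  # assumes the new package exports LLM

llm = LLM(
    model_name='Qwen/Qwen2.5-7B-Instruct',
    n_gpus=2,                 # bypass the memory heuristic and use 2 GPUs
    dtype=torch.bfloat16,
    sys_prompt='You are a helpful graph reasoning assistant.',
)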
@@ -100,8 +142,12 @@ class LLM(torch.nn.Module):
             self.autocast_context = nullcontext()
         else:
             self.device = self.llm.device
-   [original line 103 removed; content not rendered in this diff view]
+            if dtype == torch.float32:
+                self.autocast_context = nullcontext()
+            else:
+                self.autocast_context = torch.amp.autocast('cuda', dtype=dtype)
 
+    # legacy function - used for Llama 2 style prompting
     def _encode_inputs(
         self,
         question: List[str],
@@ -135,6 +181,7 @@ class LLM(torch.nn.Module):
             label_input_ids = label_input_ids + eos_tokens.input_ids
         return label_input_ids
 
+    # legacy function - used for Llama 2 style prompting
     def _input_ids(
         self,
         i: int,
@@ -149,6 +196,7 @@ class LLM(torch.nn.Module):
         input_ids += eos_user_tokens.input_ids
         return input_ids
 
+    # legacy function - used for Llama 2 style prompting
     def _inputs_embeds(
         self,
         i: int,
@@ -208,7 +256,8 @@ class LLM(torch.nn.Module):
                                       device=self.device)
         return inputs_embeds, attention_mask, label_input_ids
 
-   [original line 211 removed; content not rendered in this diff view]
+    # legacy function - used for Llama 2 style prompting
+    def _get_embeds_old(
         self,
         question: List[str],
         context: Optional[List[str]] = None,
@@ -255,6 +304,95 @@ class LLM(torch.nn.Module):
 
         return inputs_embeds, attention_mask, label_input_ids
 
+    def _get_embeds(
+        self,
+        question: List[str],
+        context: Optional[List[str]] = None,
+        embedding: Optional[List[Tensor]] = None,
+        answer: Optional[List[str]] = None,
+    ) -> tuple:
+        if not self.tokenizer.chat_template or not self.sys_prompt:
+            warnings.warn(
+                f"HuggingFace model {self.model_name} is not using a "
+                "chat template, using Llama 2 style prompting. Please "
+                "consider using a more recent model and initialize the "
+                "LLM with `sys_prompt`.", stacklevel=2)
+            return self._get_embeds_old(question, context, embedding, answer)
+        batch_label_input_ids = None
+        if answer is not None:
+            label = self.tokenizer(answer, add_special_tokens=False)
+            eos_tokens = self.tokenizer(self.tokenizer.eos_token,
+                                        add_special_tokens=False)
+            batch_label_input_ids = []
+
+        batch_inputs_embeds = []
+        batch_attention_mask = []
+        for i in range(len(question)):
+            ctx = f"{context[i]} - " if context else ""
+            messages = [
+                {
+                    "role": "system",
+                    "content": self.sys_prompt
+                },
+                {
+                    "role": "user",
+                    "content": f"{ctx} - {question[i]}"
+                },
+            ]
+            text = self.tokenizer.apply_chat_template(
+                messages,
+                tokenize=False,
+                add_generation_prompt=True,
+                enable_thinking=True,
+            )
+            text = text[len(self.tokenizer.bos_token):]
+            input_ids = self.tokenizer(text,
+                                       add_special_tokens=False).input_ids
+            if answer is not None:
+                label_input_ids = self._label_input_ids(i, label, eos_tokens)
+                input_ids += label_input_ids
+            else:
+                label_input_ids = None
+
+            bos_token = self.tokenizer(
+                self.tokenizer.bos_token,
+                add_special_tokens=False,
+                return_tensors='pt',
+            ).input_ids[0].to(self.device)
+
+            bos_embeds = self.word_embedding(bos_token)
+
+            inputs_embeds = self.word_embedding(
+                torch.tensor(input_ids, device=self.device))
+
+            to_cat = [bos_embeds]
+            if embedding is not None and embedding[i] is not None:
+                to_cat.append(embedding[i])
+            to_cat.append(inputs_embeds)
+            inputs_embeds = torch.cat(to_cat, dim=0).to(self.device)
+
+            (
+                batch_inputs_embeds,
+                batch_attention_mask,
+                batch_label_input_ids,
+            ) = self._append_embeds(
+                inputs_embeds,
+                batch_inputs_embeds,
+                batch_attention_mask,
+                label_input_ids,
+                batch_label_input_ids,
+            )
+
+        pad_token = torch.tensor(self.tokenizer.pad_token_id,
+                                 device=self.device)
+        pad_embeds = self.word_embedding(pad_token).unsqueeze(0)
+
+        inputs_embeds, attention_mask, label_input_ids = self._pad_embeds(
+            pad_embeds, batch_inputs_embeds, batch_attention_mask,
+            batch_label_input_ids)
+
+        return inputs_embeds, attention_mask, label_input_ids
+
     def forward(
         self,
         question: List[str],
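The new `_get_embeds` path above relies on the tokenizer's own chat template instead of the hard-coded Llama 2 markers. The following standalone sketch shows the prompt construction it performs; the checkpoint name is a placeholder and not taken from this diff.

# Standalone sketch of chat-template prompting; any model that ships a
# chat template works, the name below is only an example.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-0.5B-Instruct')
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "some retrieved context - What is a GNN?"},
]
# Returns one prompt string built from the model's own special tokens,
# ending with the assistant header so generation can continue from it.
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
print(prompt)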
@@ -311,17 +449,13 @@ class LLM(torch.nn.Module):
         inputs_embeds, attention_mask, _ = self._get_embeds(
             question, context, embedding)
 
-        bos_token = self.tokenizer(
-            BOS,
-            add_special_tokens=False,
-        ).input_ids[0]
-
         with self.autocast_context:
             outputs = self.llm.generate(
                 inputs_embeds=inputs_embeds,
-                bos_token_id=
+                bos_token_id=self.tokenizer.bos_token_id,
                 max_new_tokens=max_tokens,
                 attention_mask=attention_mask,
+                pad_token_id=self.tokenizer.eos_token_id,
                 use_cache=True,
             )
 
@@ -0,0 +1,158 @@
+from math import isnan
+from typing import Optional
+
+from torch_geometric.llm.models.txt2kg import \
+    _chunk_to_triples_str_cloud as call_NIM
+
+# Credit for original "Marlin Accuracy" system goes to:
+# Gilberto Titericz (NVIDIA)
+# This work is an adaptation of his for PyG
+SYSTEM_PROMPT_1 = (
+    "Instruction: You are a world class state of the art " +
+    "assistant for rating " +
+    "a User Answer given a Question. The Question is completely" +
+    " answered by the Reference Answer.\n" +
+    "Say 4, if User Answer is full contained and equivalent to" +
+    " Reference Answer" +
+    "in all terms, topics, numbers, metrics, dates and units.\n" +
+    "Say 2, if User Answer is partially contained and almost " +
+    "equivalent to Reference Answer" +
+    "in all terms, topics, numbers, metrics, dates and units.\n" +
+    "Say 0, if User Answer is not contained in Reference Answer" +
+    " or not accurate in all terms, topics," +
+    "numbers, metrics, dates and units or the User Answer do not" +
+    " answer the question.\n" +
+    "Do not explain or justify your rating. Your rating must be " +
+    "only 4, 2 or 0 according to the instructions above.\n" +
+    "### Question: \"{question}\"\n" + "### User Answer: \"{model_pred}\"\n" +
+    "### Reference Answer: \"{correct_answer}\"\n" + "The rating is:\n")
+
+SYSTEM_PROMPT_2 = (
+    "I will rate the User Answer in comparison to the Reference " +
+    "Answer for a given Question.\n" +
+    "A rating of 4 indicates that the User Answer is entirely " +
+    "consistent with the Reference Answer, covering all aspects," +
+    " topics, numbers, metrics, dates, and units.\n" +
+    "A rating of 2 signifies that the User Answer is mostly " +
+    "aligned with the Reference Answer, with minor discrepancies" +
+    " in some areas.\n" +
+    "A rating of 0 means that the User Answer is either " +
+    "inaccurate, incomplete, or unrelated to the Reference " +
+    "Answer, or it fails to address the Question.\n" +
+    "I will provide the rating without any explanation or " +
+    "justification, adhering to the following scale: " +
+    "0 (no match), 2 (partial match), 4 (exact match).\n" +
+    "Do not explain or justify my rating. My rating must" +
+    " be only 4, 2 or 0 only.\n\n" + "Question: \"{question}\"\n\n" +
+    "Reference Answer: \"{model_pred}\"\n\n" +
+    "User Answer: \"{correct_answer}\"\n\n" + "Rating: ")
+
+
+# TODO: add support for Local LM
+# TODO: add multiproc support like txt2kg
+class LLMJudge():
+    """Uses NIMs to score a triple of (question, model_pred, correct_answer)
+    This whole class is an adaptation of Gilberto's work for PyG.
+
+    Args:
+        NVIDIA_NIM_MODEL : (str, optional)
+            The name of the NVIDIA NIM model to use.
+            (default: "nvidia/llama-3.1-nemotron-70b-instruct").
+        NVIDIA_API_KEY : (str, optional)
+            The API key for accessing NVIDIA's NIM models.
+            (default: "").
+        ENDPOINT_URL : (str, optional)
+            The URL hosting your model, in case you are not using
+            the public NIM.
+            (default: "https://integrate.api.nvidia.com/v1").
+    """
+    def __init__(
+        self,
+        NVIDIA_NIM_MODEL: Optional[
+            str] = "nvidia/llama-3.1-nemotron-70b-instruct",
+        NVIDIA_API_KEY: Optional[str] = "",
+        ENDPOINT_URL: Optional[str] = "https://integrate.api.nvidia.com/v1",
+    ) -> None:
+        self.NVIDIA_API_KEY = NVIDIA_API_KEY
+        self.NIM_MODEL = NVIDIA_NIM_MODEL
+        self.ENDPOINT_URL = ENDPOINT_URL
+
+    def _process_score(self, response: str) -> float:
+        """Uses 3 and 1 even though prompt says only 0, 2, 4.
+        This is because LLMs don't always follow instructions.
+        Credit to Gilberto.
+        """
+        for i in [4, 3, 2, 1, 0]:
+            if str(i) in response:
+                return i / 4
+        return float("nan")
+
+    def _average_scores(self, score0: float, score1: float):
+        """Take the average of score0 and score1.
+        Sometimes the LLM fail to respond or have no score in the response.
+        In those cases the failed score is discarded.
+        Credit to Gilberto.
+
+        Args:
+            score0 (float): judge accuracy score.
+            score1 (float): judge accuracy score by permuting agent answer and
+                ground truth.
+
+        Returns:
+            (float) average of score0 and score1 of both contains scores,
+            otherwise pick the max.
+        """
+        score = float("nan")
+        if score0 >= 0 and score1 >= 0:
+            score = (score0 + score1) / 2
+        else:
+            score = max(score0, score1)
+        return score
+
+    def score(
+        self,
+        question: str,
+        model_pred: str,
+        correct_answer: str,
+    ) -> float:
+        """Args:
+            question (str): The original question asked to the model.
+            model_pred (str): The prediction made by the model.
+            correct_answer (str): The actual correct answer to the question.
+
+        Returns:
+            score (float): score of 0-1, may be nan due to LLM judge failure.
+                Evals should skip nan's when aggregating score.
+        """
+        prompt1 = SYSTEM_PROMPT_1.format(question=question,
+                                         model_pred=model_pred,
+                                         correct_answer=correct_answer)
+        prompt2 = SYSTEM_PROMPT_2.format(question=question,
+                                         model_pred=model_pred,
+                                         correct_answer=correct_answer)
+        score1 = float("nan")
+        score2 = float("nan")
+        for _retry in range(200):
+            try:
+                score1 = self._process_score(
+                    call_NIM(prompt1, self.NVIDIA_API_KEY, self.NIM_MODEL,
+                             self.ENDPOINT_URL, post_text=""))
+                if not isnan(score1):
+                    break
+            except ImportError:
+                raise
+            except:  # noqa
+                pass
+        for _retry in range(20):
+            try:
+                score2 = self._process_score(
+                    call_NIM(prompt2, self.NVIDIA_API_KEY, self.NIM_MODEL,
+                             self.ENDPOINT_URL, post_text=""))
+                if not isnan(score2):
+                    break
+            except ImportError:
+                raise
+            except:  # noqa
+                pass
+
+        return self._average_scores(score1, score2)
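A minimal usage sketch of the new judge follows. It assumes a valid NIM API key and imports the class straight from the new module shown above; the question, prediction, and reference answer are placeholders.

# Hypothetical usage; the API key and example strings are placeholders.
from torch_geometric.llm.models.llm_judge import LLMJudge

judge = LLMJudge(NVIDIA_API_KEY="<your key>")
s = judge.score(
    question="Which paper introduced GraphSAGE?",
    model_pred="Hamilton et al., 2017.",
    correct_answer="Inductive Representation Learning on Large Graphs (2017).",
)
# s lies in [0, 1]; NaN means both judge calls failed and should be skipped
# when aggregating evaluation results.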
@@ -3,8 +3,8 @@ from typing import List, Optional
 import torch
 from torch import Tensor
 
+from torch_geometric.llm.models.llm import BOS, LLM, MAX_NEW_TOKENS
 from torch_geometric.nn.attention import QFormer
-from torch_geometric.nn.nlp.llm import BOS, LLM, MAX_NEW_TOKENS
 from torch_geometric.utils import to_dense_batch
 
 
@@ -4,6 +4,7 @@ from typing import List, Optional, Union
 import torch
 import torch.nn.functional as F
 from torch import Tensor
+from tqdm import tqdm
 
 
 class PoolingStrategy(Enum):
@@ -31,6 +32,19 @@ class SentenceTransformer(torch.nn.Module):
         if self.tokenizer.pad_token is None:
             self.tokenizer.pad_token = self.tokenizer.eos_token
 
+        # Maximum sequence length from the model configuration (e.g. 8192 for
+        # models like ModernBERT)
+        self.max_seq_length = self.model.config.max_position_embeddings
+        """
+        Some models define a max sequence length in their configuration. Others
+        only in the tokenizer. This is a hacky heuristic to find the max
+        sequence length that works for the model.
+        """
+        probe_tokens = self.tokenizer("hacky heuristic", padding='max_length',
+                                      return_tensors='pt')
+        self.max_seq_length = min(self.max_seq_length,
+                                  probe_tokens.input_ids.shape[1])
+
     def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
         out = self.model(input_ids=input_ids, attention_mask=attention_mask)
 
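The probe above works because `padding='max_length'` pads a short input up to the tokenizer's configured maximum, so comparing that length with `config.max_position_embeddings` picks the smaller, safer limit. A standalone illustration of the same heuristic; the checkpoint name is a placeholder, not taken from this diff.

# Standalone illustration of the max-length probe used above.
from transformers import AutoConfig, AutoTokenizer

name = 'sentence-transformers/all-MiniLM-L6-v2'  # placeholder checkpoint
config = AutoConfig.from_pretrained(name)
tokenizer = AutoTokenizer.from_pretrained(name)

probe = tokenizer("hacky heuristic", padding='max_length', return_tensors='pt')
max_seq_length = min(config.max_position_embeddings, probe.input_ids.shape[1])
print(max_seq_length)  # the effective limit to pass as `max_length`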
@@ -67,6 +81,7 @@ class SentenceTransformer(torch.nn.Module):
                 padding=True,
                 truncation=True,
                 return_tensors='pt',
+                max_length=self.max_seq_length,
             )
             input_ids.append(token.input_ids.to(self.device))
             attention_masks.append(token.attention_mask.to(self.device))
@@ -88,6 +103,7 @@ class SentenceTransformer(torch.nn.Module):
         text: List[str],
         batch_size: Optional[int] = None,
         output_device: Optional[Union[torch.device, str]] = None,
+        verbose=False,
     ) -> Tensor:
         is_empty = len(text) == 0
         text = ['dummy'] if is_empty else text
@@ -95,20 +111,38 @@ class SentenceTransformer(torch.nn.Module):
         batch_size = len(text) if batch_size is None else batch_size
 
         embs: List[Tensor] = []
-   [original line 98 removed; content not rendered in this diff view]
+        loader = range(0, len(text), batch_size)
+        if verbose:
+            loader = tqdm(
+                loader, desc="Encoding " + str(len(text)) +
+                " strings w/ SentenceTransformer")
+        for start in loader:
             token = self.tokenizer(
                 text[start:start + batch_size],
                 padding=True,
                 truncation=True,
                 return_tensors='pt',
+                max_length=self.max_seq_length,
             )
-   [original lines 105-111 removed; content not rendered in this diff view]
+            try:
+                emb = self(
+                    input_ids=token.input_ids.to(self.device),
+                    attention_mask=token.attention_mask.to(self.device),
+                ).to(output_device)
+
+                embs.append(emb)
+            except:  # noqa
+                # fallback to using CPU for huge strings that cause OOMs
+                print("Sentence Transformer failed on cuda, trying w/ cpu...")
+                previous_device = self.device
+                self.model = self.model.to("cpu")
+                emb = self(
+                    input_ids=token.input_ids.to(self.device),
+                    attention_mask=token.attention_mask.to(self.device),
+                ).to(output_device)
+
+                embs.append(emb)
+                self.model = self.model.to(previous_device)
 
         out = torch.cat(embs, dim=0) if len(embs) > 1 else embs[0]
         out = out[:0] if is_empty else out
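Putting the encoder changes together, a typical call with the new `verbose` flag would look roughly like this. The import path assumes the module move in this diff, and the checkpoint name and input strings are placeholders.

# Hypothetical usage of the updated encoder; checkpoint and inputs are
# placeholders, and the import path assumes the new package layout.
from torch_geometric.llm.models import SentenceTransformer

model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
texts = ["graph neural networks", "retrieval augmented generation"]
emb = model.encode(texts, batch_size=2, output_device='cpu', verbose=True)
print(emb.shape)  # [2, hidden_dim]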