pyg-nightly 2.7.0.dev20250904__py3-none-any.whl → 2.7.0.dev20250906__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. {pyg_nightly-2.7.0.dev20250904.dist-info → pyg_nightly-2.7.0.dev20250906.dist-info}/METADATA +2 -1
  2. {pyg_nightly-2.7.0.dev20250904.dist-info → pyg_nightly-2.7.0.dev20250906.dist-info}/RECORD +34 -27
  3. torch_geometric/__init__.py +1 -1
  4. torch_geometric/data/__init__.py +0 -5
  5. torch_geometric/data/lightning/datamodule.py +2 -2
  6. torch_geometric/datasets/molecule_gpt_dataset.py +1 -1
  7. torch_geometric/datasets/web_qsp_dataset.py +262 -210
  8. torch_geometric/graphgym/imports.py +2 -2
  9. torch_geometric/llm/__init__.py +9 -0
  10. torch_geometric/{data → llm}/large_graph_indexer.py +124 -61
  11. torch_geometric/llm/models/__init__.py +23 -0
  12. torch_geometric/{nn → llm}/models/g_retriever.py +68 -49
  13. torch_geometric/{nn → llm}/models/git_mol.py +1 -1
  14. torch_geometric/{nn/nlp → llm/models}/llm.py +167 -33
  15. torch_geometric/llm/models/llm_judge.py +158 -0
  16. torch_geometric/{nn → llm}/models/molecule_gpt.py +1 -1
  17. torch_geometric/{nn/nlp → llm/models}/sentence_transformer.py +42 -8
  18. torch_geometric/llm/models/txt2kg.py +353 -0
  19. torch_geometric/llm/rag_loader.py +154 -0
  20. torch_geometric/llm/utils/backend_utils.py +442 -0
  21. torch_geometric/llm/utils/feature_store.py +169 -0
  22. torch_geometric/llm/utils/graph_store.py +199 -0
  23. torch_geometric/llm/utils/vectorrag.py +124 -0
  24. torch_geometric/loader/__init__.py +0 -4
  25. torch_geometric/metrics/link_pred.py +13 -2
  26. torch_geometric/nn/__init__.py +0 -1
  27. torch_geometric/nn/models/__init__.py +0 -10
  28. torch_geometric/nn/models/sgformer.py +2 -0
  29. torch_geometric/utils/cross_entropy.py +34 -13
  30. torch_geometric/loader/rag_loader.py +0 -107
  31. torch_geometric/nn/nlp/__init__.py +0 -9
  32. {pyg_nightly-2.7.0.dev20250904.dist-info → pyg_nightly-2.7.0.dev20250906.dist-info}/WHEEL +0 -0
  33. {pyg_nightly-2.7.0.dev20250904.dist-info → pyg_nightly-2.7.0.dev20250906.dist-info}/licenses/LICENSE +0 -0
  34. /torch_geometric/{nn → llm}/models/glem.py +0 -0
  35. /torch_geometric/{nn → llm}/models/protein_mpnn.py +0 -0
  36. /torch_geometric/{nn/nlp → llm/models}/vision_transformer.py +0 -0
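Most of the churn in this release is a module move: the GNN+LLM code that previously lived under torch_geometric.nn.nlp, torch_geometric.nn.models, torch_geometric.data, and torch_geometric.loader now lives under the new torch_geometric.llm package. A minimal import-migration sketch, using only the new paths shown in the file list above (whether the classes are also re-exported from torch_geometric.llm.models directly is an assumption, since the new __init__.py contents are not shown here):

    # Before (2.7.0.dev20250904):
    #   from torch_geometric.nn.nlp.llm import LLM
    #   from torch_geometric.nn.nlp.sentence_transformer import SentenceTransformer
    #   from torch_geometric.nn.models.g_retriever import GRetriever

    # After (2.7.0.dev20250906), paths taken from the file list above:
    from torch_geometric.llm.models.llm import LLM
    from torch_geometric.llm.models.sentence_transformer import SentenceTransformer
    from torch_geometric.llm.models.g_retriever import GRetriever

The per-file diffs follow.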
torch_geometric/{nn/nlp → llm/models}/llm.py
@@ -10,15 +10,17 @@ try:
 except ImportError:
     BatchEncoding = Dict

-BOS = '<s>[INST]'
-EOS_USER = '[/INST]'
-EOS = '[/s]'
 IGNORE_INDEX = -100
 MAX_TXT_LEN = 512
-MAX_NEW_TOKENS = 32
+MAX_NEW_TOKENS = 128
 PAD_TOKEN_ID = 0
 PADDING_SIDE = 'left'

+# legacy constants - used for Llama 2 style prompting
+BOS = '<s>[INST]'
+EOS_USER = '[/INST]'
+EOS = '[/s]'
+

 def get_llm_kwargs(required_memory: int, dtype=torch.dtype) -> Dict[str, Any]:
     torch.cuda.empty_cache()
@@ -50,49 +52,89 @@ class LLM(torch.nn.Module):
     r"""A wrapper around a Large Language Model (LLM) from HuggingFace.

     Args:
-        model_name (str): The HuggingFace model name, *e.g.*, :obj:`"llama2"`
-            or :obj:`"gemma"`.
-        num_params (int, optional): An integer representing how many parameters
+        model_name (str): The HuggingFace model name.
+        num_params (float, optional): The number of parameters
             the HuggingFace model has, in billions. This is used to
-            automatically allocate the correct number of GPUs needed, given the
-            available GPU memory of your GPUs. If not specified, the number of
-            parameters is determined using the `huggingface_hub` module.
+            automatically allocate the correct number of GPUs needed (using a
+            rough heuristic), given the available GPU memory of your GPUs. If
+            not specified, the number of parameters is determined using the
+            `huggingface_hub` module.
+        n_gpus (int, optional): Number of GPUs to use. Designed for advanced
+            users who want to set this manually and override the automatic
+            set-up mechanism.
         dtype (torch.dtype, optional): The data type to use for the LLM.
            (default :obj: `torch.bfloat16`)
+        sys_prompt (str, optional): A system prompt to use for the LLM.
+            (default: :obj: `None`)
     """
     def __init__(
         self,
         model_name: str,
-        num_params: Optional[int] = None,
+        num_params: Optional[float] = None,
+        n_gpus: Optional[int] = None,
         dtype: Optional[torch.dtype] = torch.bfloat16,
+        sys_prompt: Optional[str] = None,
     ) -> None:
         super().__init__()

         self.model_name = model_name

         from transformers import AutoModelForCausalLM, AutoTokenizer
-
-        if num_params is None:
-            from huggingface_hub import get_safetensors_metadata
-            safetensors_metadata = get_safetensors_metadata(model_name)
-            param_count = safetensors_metadata.parameter_count
-            num_params = list(param_count.values())[0] // 10**9
-
-        # A rough heuristic on GPU memory requirements, e.g., we found that
-        # LLAMA2 (7B parameters) fits on a 85GB GPU.
-        required_memory = 85 * num_params / 7
-        kwargs = get_llm_kwargs(required_memory, dtype)
+        if n_gpus is None:
+            if num_params is None:
+                from huggingface_hub import get_safetensors_metadata
+                safetensors_metadata = get_safetensors_metadata(model_name)
+                param_count = safetensors_metadata.parameter_count
+                num_params = float(list(param_count.values())[0] // 10**9)
+
+            # A rough heuristic on GPU memory requirements, e.g., we found that
+            # LLAMA2 (7B parameters) fits on a 85GB GPU.
+            required_memory = 85 * num_params / 7
+            kwargs = get_llm_kwargs(required_memory, dtype)
+        else:
+            gpu_memory: List[int] = []
+            for i in range(n_gpus):
+                gpu_memory.append(torch.cuda.mem_get_info(i)[0] // 1024**3)
+            kwargs = dict(revision='main')
+            kwargs['max_memory'] = {
+                i: f'{memory}GiB'
+                for i, memory in enumerate(gpu_memory)
+            }
+            kwargs['low_cpu_mem_usage'] = True
+            kwargs['device_map'] = 'auto'
+            kwargs['torch_dtype'] = dtype

         print(f"Setting up '{model_name}' with configuration: {kwargs}")
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_name,
             use_fast=False,
         )
-        self.tokenizer.pad_token_id = PAD_TOKEN_ID
-        self.tokenizer.padding_side = PADDING_SIDE
+        if self.tokenizer.chat_template and self.tokenizer.bos_token is None:
+            dummy_convo = [
+                {
+                    "role": "system",
+                    "content": "dummy"
+                },
+                {
+                    "role": "user",
+                    "content": "convo"
+                },
+            ]
+            text = self.tokenizer.apply_chat_template(
+                dummy_convo,
+                tokenize=True,
+            )
+            self.tokenizer.bos_token = self.tokenizer.decode(text[0])
+        if self.tokenizer.pad_token_id is None:
+            self.tokenizer.pad_token_id = PAD_TOKEN_ID
+        if self.tokenizer.padding_side is None:
+            self.tokenizer.padding_side = PADDING_SIDE
         self.llm = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
         self.word_embedding = self.llm.model.get_input_embeddings()
-
+        if sys_prompt is not None:
+            self.sys_prompt = sys_prompt
+        else:
+            self.sys_prompt = ""
         if 'max_memory' not in kwargs:  # Pure CPU:
             warnings.warn("LLM is being used on CPU, which may be slow",
                           stacklevel=2)
@@ -100,8 +142,12 @@ class LLM(torch.nn.Module):
             self.autocast_context = nullcontext()
         else:
             self.device = self.llm.device
-            self.autocast_context = torch.amp.autocast('cuda', dtype=dtype)
+            if dtype == torch.float32:
+                self.autocast_context = nullcontext()
+            else:
+                self.autocast_context = torch.amp.autocast('cuda', dtype=dtype)

+    # legacy function - used for Llama 2 style prompting
     def _encode_inputs(
         self,
         question: List[str],
@@ -135,6 +181,7 @@ class LLM(torch.nn.Module):
         label_input_ids = label_input_ids + eos_tokens.input_ids
         return label_input_ids

+    # legacy function - used for Llama 2 style prompting
     def _input_ids(
         self,
         i: int,
@@ -149,6 +196,7 @@ class LLM(torch.nn.Module):
         input_ids += eos_user_tokens.input_ids
         return input_ids

+    # legacy function - used for Llama 2 style prompting
     def _inputs_embeds(
         self,
         i: int,
@@ -208,7 +256,8 @@ class LLM(torch.nn.Module):
                                            device=self.device)
         return inputs_embeds, attention_mask, label_input_ids

-    def _get_embeds(
+    # legacy function - used for Llama 2 style prompting
+    def _get_embeds_old(
         self,
         question: List[str],
         context: Optional[List[str]] = None,
@@ -255,6 +304,95 @@ class LLM(torch.nn.Module):

         return inputs_embeds, attention_mask, label_input_ids

+    def _get_embeds(
+        self,
+        question: List[str],
+        context: Optional[List[str]] = None,
+        embedding: Optional[List[Tensor]] = None,
+        answer: Optional[List[str]] = None,
+    ) -> tuple:
+        if not self.tokenizer.chat_template or not self.sys_prompt:
+            warnings.warn(
+                f"HuggingFace model {self.model_name} is not using a "
+                "chat template, using Llama 2 style prompting. Please "
+                "consider using a more recent model and initialize the "
+                "LLM with `sys_prompt`.", stacklevel=2)
+            return self._get_embeds_old(question, context, embedding, answer)
+        batch_label_input_ids = None
+        if answer is not None:
+            label = self.tokenizer(answer, add_special_tokens=False)
+            eos_tokens = self.tokenizer(self.tokenizer.eos_token,
+                                        add_special_tokens=False)
+            batch_label_input_ids = []
+
+        batch_inputs_embeds = []
+        batch_attention_mask = []
+        for i in range(len(question)):
+            ctx = f"{context[i]} - " if context else ""
+            messages = [
+                {
+                    "role": "system",
+                    "content": self.sys_prompt
+                },
+                {
+                    "role": "user",
+                    "content": f"{ctx} - {question[i]}"
+                },
+            ]
+            text = self.tokenizer.apply_chat_template(
+                messages,
+                tokenize=False,
+                add_generation_prompt=True,
+                enable_thinking=True,
+            )
+            text = text[len(self.tokenizer.bos_token):]
+            input_ids = self.tokenizer(text,
+                                       add_special_tokens=False).input_ids
+            if answer is not None:
+                label_input_ids = self._label_input_ids(i, label, eos_tokens)
+                input_ids += label_input_ids
+            else:
+                label_input_ids = None
+
+            bos_token = self.tokenizer(
+                self.tokenizer.bos_token,
+                add_special_tokens=False,
+                return_tensors='pt',
+            ).input_ids[0].to(self.device)
+
+            bos_embeds = self.word_embedding(bos_token)
+
+            inputs_embeds = self.word_embedding(
+                torch.tensor(input_ids, device=self.device))
+
+            to_cat = [bos_embeds]
+            if embedding is not None and embedding[i] is not None:
+                to_cat.append(embedding[i])
+            to_cat.append(inputs_embeds)
+            inputs_embeds = torch.cat(to_cat, dim=0).to(self.device)
+
+            (
+                batch_inputs_embeds,
+                batch_attention_mask,
+                batch_label_input_ids,
+            ) = self._append_embeds(
+                inputs_embeds,
+                batch_inputs_embeds,
+                batch_attention_mask,
+                label_input_ids,
+                batch_label_input_ids,
+            )
+
+        pad_token = torch.tensor(self.tokenizer.pad_token_id,
+                                 device=self.device)
+        pad_embeds = self.word_embedding(pad_token).unsqueeze(0)
+
+        inputs_embeds, attention_mask, label_input_ids = self._pad_embeds(
+            pad_embeds, batch_inputs_embeds, batch_attention_mask,
+            batch_label_input_ids)
+
+        return inputs_embeds, attention_mask, label_input_ids
+
     def forward(
         self,
         question: List[str],
@@ -311,17 +449,13 @@ class LLM(torch.nn.Module):
         inputs_embeds, attention_mask, _ = self._get_embeds(
             question, context, embedding)

-        bos_token = self.tokenizer(
-            BOS,
-            add_special_tokens=False,
-        ).input_ids[0]
-
         with self.autocast_context:
             outputs = self.llm.generate(
                 inputs_embeds=inputs_embeds,
-                bos_token_id=bos_token,
+                bos_token_id=self.tokenizer.bos_token_id,
                 max_new_tokens=max_tokens,
                 attention_mask=attention_mask,
+                pad_token_id=self.tokenizer.eos_token_id,
                 use_cache=True,
             )
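Taken together, the llm.py changes above add manual GPU selection (n_gpus), a configurable system prompt, chat-template-based prompting with the old Llama 2 style path kept as a legacy fallback, and a larger default generation budget (MAX_NEW_TOKENS = 128). A hedged usage sketch; the model name is a placeholder and the inference() call assumes the pre-existing LLM API, which this diff does not show:

    import torch

    from torch_geometric.llm.models.llm import LLM

    # Sketch only: constructor arguments follow the diff above; the model name
    # is hypothetical and inference() is assumed to keep its existing signature.
    llm = LLM(
        model_name='meta-llama/Llama-3.1-8B-Instruct',  # placeholder choice
        num_params=8.0,        # now a float, in billions of parameters
        n_gpus=1,              # new: bypass the automatic GPU allocation heuristic
        dtype=torch.bfloat16,
        sys_prompt='You are a helpful graph reasoning assistant.',  # new argument
    )
    out = llm.inference(question=['What does G-Retriever do?'], max_tokens=128)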
 
torch_geometric/llm/models/llm_judge.py (new file)
@@ -0,0 +1,158 @@
+from math import isnan
+from typing import Optional
+
+from torch_geometric.llm.models.txt2kg import \
+    _chunk_to_triples_str_cloud as call_NIM
+
+# Credit for original "Marlin Accuracy" system goes to:
+# Gilberto Titericz (NVIDIA)
+# This work is an adaptation of his for PyG
+SYSTEM_PROMPT_1 = (
+    "Instruction: You are a world class state of the art " +
+    "assistant for rating " +
+    "a User Answer given a Question. The Question is completely" +
+    " answered by the Reference Answer.\n" +
+    "Say 4, if User Answer is full contained and equivalent to" +
+    " Reference Answer" +
+    "in all terms, topics, numbers, metrics, dates and units.\n" +
+    "Say 2, if User Answer is partially contained and almost " +
+    "equivalent to Reference Answer" +
+    "in all terms, topics, numbers, metrics, dates and units.\n" +
+    "Say 0, if User Answer is not contained in Reference Answer" +
+    " or not accurate in all terms, topics," +
+    "numbers, metrics, dates and units or the User Answer do not" +
+    " answer the question.\n" +
+    "Do not explain or justify your rating. Your rating must be " +
+    "only 4, 2 or 0 according to the instructions above.\n" +
+    "### Question: \"{question}\"\n" + "### User Answer: \"{model_pred}\"\n" +
+    "### Reference Answer: \"{correct_answer}\"\n" + "The rating is:\n")
+
+SYSTEM_PROMPT_2 = (
+    "I will rate the User Answer in comparison to the Reference " +
+    "Answer for a given Question.\n" +
+    "A rating of 4 indicates that the User Answer is entirely " +
+    "consistent with the Reference Answer, covering all aspects," +
+    " topics, numbers, metrics, dates, and units.\n" +
+    "A rating of 2 signifies that the User Answer is mostly " +
+    "aligned with the Reference Answer, with minor discrepancies" +
+    " in some areas.\n" +
+    "A rating of 0 means that the User Answer is either " +
+    "inaccurate, incomplete, or unrelated to the Reference " +
+    "Answer, or it fails to address the Question.\n" +
+    "I will provide the rating without any explanation or " +
+    "justification, adhering to the following scale: " +
+    "0 (no match), 2 (partial match), 4 (exact match).\n" +
+    "Do not explain or justify my rating. My rating must" +
+    " be only 4, 2 or 0 only.\n\n" + "Question: \"{question}\"\n\n" +
+    "Reference Answer: \"{model_pred}\"\n\n" +
+    "User Answer: \"{correct_answer}\"\n\n" + "Rating: ")
+
+
+# TODO: add support for Local LM
+# TODO: add multiproc support like txt2kg
+class LLMJudge():
+    """Uses NIMs to score a triple of (question, model_pred, correct_answer).
+    This whole class is an adaptation of Gilberto's work for PyG.
+
+    Args:
+        NVIDIA_NIM_MODEL : (str, optional)
+            The name of the NVIDIA NIM model to use.
+            (default: "nvidia/llama-3.1-nemotron-70b-instruct").
+        NVIDIA_API_KEY : (str, optional)
+            The API key for accessing NVIDIA's NIM models.
+            (default: "").
+        ENDPOINT_URL : (str, optional)
+            The URL hosting your model, in case you are not using
+            the public NIM.
+            (default: "https://integrate.api.nvidia.com/v1").
+    """
+    def __init__(
+        self,
+        NVIDIA_NIM_MODEL: Optional[
+            str] = "nvidia/llama-3.1-nemotron-70b-instruct",
+        NVIDIA_API_KEY: Optional[str] = "",
+        ENDPOINT_URL: Optional[str] = "https://integrate.api.nvidia.com/v1",
+    ) -> None:
+        self.NVIDIA_API_KEY = NVIDIA_API_KEY
+        self.NIM_MODEL = NVIDIA_NIM_MODEL
+        self.ENDPOINT_URL = ENDPOINT_URL
+
+    def _process_score(self, response: str) -> float:
+        """Uses 3 and 1 even though prompt says only 0, 2, 4.
+        This is because LLMs don't always follow instructions.
+        Credit to Gilberto.
+        """
+        for i in [4, 3, 2, 1, 0]:
+            if str(i) in response:
+                return i / 4
+        return float("nan")
+
+    def _average_scores(self, score0: float, score1: float):
+        """Take the average of score0 and score1.
+        Sometimes the LLM fails to respond or has no score in the response.
+        In those cases the failed score is discarded.
+        Credit to Gilberto.
+
+        Args:
+            score0 (float): judge accuracy score.
+            score1 (float): judge accuracy score by permuting agent answer and
+                ground truth.
+
+        Returns:
+            (float) average of score0 and score1 if both contain scores,
+            otherwise the max of the two.
+        """
+        score = float("nan")
+        if score0 >= 0 and score1 >= 0:
+            score = (score0 + score1) / 2
+        else:
+            score = max(score0, score1)
+        return score
+
+    def score(
+        self,
+        question: str,
+        model_pred: str,
+        correct_answer: str,
+    ) -> float:
+        """Args:
+            question (str): The original question asked to the model.
+            model_pred (str): The prediction made by the model.
+            correct_answer (str): The actual correct answer to the question.
+
+        Returns:
+            score (float): score of 0-1, may be nan due to LLM judge failure.
+                Evals should skip nan's when aggregating score.
+        """
+        prompt1 = SYSTEM_PROMPT_1.format(question=question,
+                                         model_pred=model_pred,
+                                         correct_answer=correct_answer)
+        prompt2 = SYSTEM_PROMPT_2.format(question=question,
+                                         model_pred=model_pred,
+                                         correct_answer=correct_answer)
+        score1 = float("nan")
+        score2 = float("nan")
+        for _retry in range(200):
+            try:
+                score1 = self._process_score(
+                    call_NIM(prompt1, self.NVIDIA_API_KEY, self.NIM_MODEL,
+                             self.ENDPOINT_URL, post_text=""))
+                if not isnan(score1):
+                    break
+            except ImportError:
+                raise
+            except:  # noqa
+                pass
+        for _retry in range(20):
+            try:
+                score2 = self._process_score(
+                    call_NIM(prompt2, self.NVIDIA_API_KEY, self.NIM_MODEL,
+                             self.ENDPOINT_URL, post_text=""))
+                if not isnan(score2):
+                    break
+            except ImportError:
+                raise
+            except:  # noqa
+                pass
+
+        return self._average_scores(score1, score2)
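The new llm_judge.py above scores a (question, model_pred, correct_answer) triple twice, once per system prompt (the second prompt permutes the answers), and maps the 0/2/4 ratings onto [0, 1]. A hedged usage sketch; the API key is a placeholder and the call needs network access to the NIM endpoint:

    from torch_geometric.llm.models.llm_judge import LLMJudge

    # Sketch only: the key is a placeholder; score() averages the two prompt
    # ratings (each normalized to [0, 1]) and may return NaN if both NIM calls fail.
    judge = LLMJudge(NVIDIA_API_KEY='<your NVIDIA API key>')
    acc = judge.score(
        question='In which year was PyTorch Geometric first released?',
        model_pred='PyTorch Geometric was first released in 2019.',
        correct_answer='2019',
    )
    print(acc)  # 1.0 for an exact match, 0.5 for a partial match, NaN on failure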
torch_geometric/{nn → llm}/models/git_mol.py
@@ -3,8 +3,8 @@ from typing import List, Optional
 import torch
 from torch import Tensor

+from torch_geometric.llm.models.llm import BOS, LLM, MAX_NEW_TOKENS
 from torch_geometric.nn.attention import QFormer
-from torch_geometric.nn.nlp.llm import BOS, LLM, MAX_NEW_TOKENS
 from torch_geometric.utils import to_dense_batch


torch_geometric/{nn/nlp → llm/models}/sentence_transformer.py
@@ -4,6 +4,7 @@ from typing import List, Optional, Union
 import torch
 import torch.nn.functional as F
 from torch import Tensor
+from tqdm import tqdm


 class PoolingStrategy(Enum):
@@ -31,6 +32,19 @@ class SentenceTransformer(torch.nn.Module):
         if self.tokenizer.pad_token is None:
             self.tokenizer.pad_token = self.tokenizer.eos_token

+        # Maximum sequence length from the model configuration (e.g. 8192 for
+        # models like ModernBERT)
+        self.max_seq_length = self.model.config.max_position_embeddings
+        """
+        Some models define a max sequence length in their configuration. Others
+        only in the tokenizer. This is a hacky heuristic to find the max
+        sequence length that works for the model.
+        """
+        probe_tokens = self.tokenizer("hacky heuristic", padding='max_length',
+                                      return_tensors='pt')
+        self.max_seq_length = min(self.max_seq_length,
+                                  probe_tokens.input_ids.shape[1])
+
     def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
         out = self.model(input_ids=input_ids, attention_mask=attention_mask)

@@ -67,6 +81,7 @@ class SentenceTransformer(torch.nn.Module):
                 padding=True,
                 truncation=True,
                 return_tensors='pt',
+                max_length=self.max_seq_length,
             )
             input_ids.append(token.input_ids.to(self.device))
             attention_masks.append(token.attention_mask.to(self.device))
@@ -88,6 +103,7 @@ class SentenceTransformer(torch.nn.Module):
         text: List[str],
         batch_size: Optional[int] = None,
         output_device: Optional[Union[torch.device, str]] = None,
+        verbose=False,
     ) -> Tensor:
         is_empty = len(text) == 0
         text = ['dummy'] if is_empty else text
@@ -95,20 +111,38 @@ class SentenceTransformer(torch.nn.Module):
         batch_size = len(text) if batch_size is None else batch_size

         embs: List[Tensor] = []
-        for start in range(0, len(text), batch_size):
+        loader = range(0, len(text), batch_size)
+        if verbose:
+            loader = tqdm(
+                loader, desc="Encoding " + str(len(text)) +
+                " strings w/ SentenceTransformer")
+        for start in loader:
             token = self.tokenizer(
                 text[start:start + batch_size],
                 padding=True,
                 truncation=True,
                 return_tensors='pt',
+                max_length=self.max_seq_length,
             )
-
-            emb = self(
-                input_ids=token.input_ids.to(self.device),
-                attention_mask=token.attention_mask.to(self.device),
-            ).to(output_device)
-
-            embs.append(emb)
+            try:
+                emb = self(
+                    input_ids=token.input_ids.to(self.device),
+                    attention_mask=token.attention_mask.to(self.device),
+                ).to(output_device)
+
+                embs.append(emb)
+            except:  # noqa
+                # fallback to using CPU for huge strings that cause OOMs
+                print("Sentence Transformer failed on cuda, trying w/ cpu...")
+                previous_device = self.device
+                self.model = self.model.to("cpu")
+                emb = self(
+                    input_ids=token.input_ids.to(self.device),
+                    attention_mask=token.attention_mask.to(self.device),
+                ).to(output_device)
+
+                embs.append(emb)
+                self.model = self.model.to(previous_device)

         out = torch.cat(embs, dim=0) if len(embs) > 1 else embs[0]
         out = out[:0] if is_empty else out
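The sentence_transformer.py changes above cap tokenization at a probed max_seq_length, add an optional tqdm progress bar, and fall back to CPU when a batch fails on the GPU. A hedged usage sketch; the model name is illustrative and the constructor is assumed to keep its existing signature:

    from torch_geometric.llm.models.sentence_transformer import SentenceTransformer

    # Sketch only: model name is a placeholder; verbose=True is the new flag that
    # wraps the batch loop in a tqdm progress bar, and inputs are now truncated
    # to the max_seq_length probed in __init__.
    model = SentenceTransformer(model_name='sentence-transformers/all-MiniLM-L6-v2')
    emb = model.encode(
        ['a very long document ' * 1000, 'a short sentence'],
        batch_size=2,
        verbose=True,
    )
    print(emb.shape)  # e.g. torch.Size([2, 384]) for a MiniLM-sized encoder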