plancraft 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

models/generators.py DELETED
@@ -1,483 +0,0 @@
-import logging
-import os
-import time
-
-import torch
-from dotenv import load_dotenv
-from openai import OpenAI
-from PIL import Image
-from transformers import (
-    AutoModelForCausalLM,
-    AutoModelForVision2Seq,
-    AutoProcessor,
-    AutoTokenizer,
-    BitsAndBytesConfig,
-)
-from transformers.cache_utils import DynamicCache
-
-from plancraft.models.base import History
-from plancraft.models.oam import PlancraftOAM
-from plancraft.models.utils import (
-    get_downloaded_models,
-    numpy_to_base64,
-    tokenize,
-)
-
-
-logger = logging.getLogger(__name__)
-
-load_dotenv()
-
-
-class TransformersGenerator:
-    def __init__(
-        self,
-        model_name: str,
-        tokenizer_name: str = "same",
-        quantize=False,
-        use_images=False,
-        use_hot_cache=True,
-        adapter_name="",
-        **kwargs,
-    ):
-        self.model_name = model_name
-        self.use_hot_cache = use_hot_cache
-
-        if tokenizer_name == "same":
-            tokenizer_name = model_name
-
-        self.use_images = use_images
-        model_name, model_kwargs = self.build_model_kwargs(
-            model_name, quantize=quantize
-        )
-        self.processor = None
-        if "idefics" in model_name:
-            assert use_images, "Idefics model requires multimodal input"
-            self.tokenizer = AutoProcessor.from_pretrained(
-                tokenizer_name,
-                **model_kwargs,
-            )
-            self.tokenizer.eos_token_id = self.tokenizer.tokenizer.eos_token_id
-            logger.info("Loading model")
-            time_now = time.time()
-            self.model = AutoModelForVision2Seq.from_pretrained(
-                model_name,
-                device_map="auto",
-                **model_kwargs,
-            )
-            logger.info(f"Model loaded in {time.time() - time_now:.2f} seconds")
-            # set pad_token_id
-            if self.tokenizer.tokenizer.pad_token_id:
-                self.pad_token_id = self.tokenizer.tokenizer.pad_token_id
-            else:
-                self.pad_token_id = self.tokenizer.tokenizer.eos_token_id
-        else:
-            self.tokenizer = AutoTokenizer.from_pretrained(
-                tokenizer_name,
-                token=os.getenv("HF_TOKEN"),  # trust_remote_code=True
-                padding_side="left",  # ensure that the padding is on the left
-            )
-            logger.info("Loading model")
-            time_now = time.time()
-            self.model = AutoModelForCausalLM.from_pretrained(
-                model_name,
-                device_map="auto",
-                **model_kwargs,
-            )
-            logger.info(f"Model loaded in {time.time() - time_now:.2f} seconds")
-
-            # load OA adapter
-            if adapter_name != "":
-                logger.info(f"Loading adapter and tokenizer from {adapter_name}")
-                self.tokenizer = AutoTokenizer.from_pretrained(
-                    adapter_name,
-                    padding_side="left",
-                )
-                self.model.resize_token_embeddings(len(self.tokenizer))
-                self.model.load_adapter(adapter_name)
-
-            # set pad_token_id
-            if self.tokenizer.pad_token_id:
-                self.pad_token_id = self.tokenizer.pad_token_id
-            else:
-                self.tokenizer.pad_token = self.tokenizer.eos_token
-                self.pad_token_id = self.tokenizer.eos_token_id
-
-        # compile
-        time_now = time.time()
-        self.model = torch.compile(self.model)
-        logger.info(f"Model compiled in {time.time() - time_now:.2f} seconds")
-
-        self.model.eval()
-        if self.pad_token_id is None:
-            self.tokenizer.pad_token = self.tokenizer.eos_token
-            self.model.config.pad_token_id = self.model.config.eos_token_id
-        self.tokenizer.truncation_side = "left"
-
-        self.past_key_values_kwargs = {}
-        self.past_token_ids = None
-
-    def truncate_kv_cache(self, new_token_ids: torch.Tensor):
-        """
-        Truncate the key-value cache to the length of the longest prefix shared by past_ids and new_ids.
-        Uses:
-            past_ids: torch.Tensor [B, T]
-            new_ids: torch.Tensor [B, T]
-            kv_cache: tuple[tuple[torch.Tensor]]: tuple of key-value cache tensors
-
-        NOTE: in the worst case (batch_size == 1) this effectively only caches the system prompt
-        """
-        if (
-            self.past_token_ids is None
-            or "past_key_values" not in self.past_key_values_kwargs
-        ):
-            return
-
-        # caching doesn't seem to work with multimodal models
-        if self.use_images:
-            self.past_key_values_kwargs = {}
-            return
-
-        past_batch_size, past_seq_len = self.past_token_ids.shape
-        new_batch_size, new_seq_len = new_token_ids.shape
-
-        # If the batch size has changed, reset the cache
-        if past_batch_size != new_batch_size:
-            self.past_key_values_kwargs = {}
-            return
-
-        min_shape = min(past_seq_len, new_seq_len)
-        compare_past = (
-            self.past_token_ids[:, :min_shape] != new_token_ids[:, :min_shape]
-        )
-
-        # All tokens are the same - no need to truncate
-        if not compare_past.any():
-            return
-
-        # Find the first token that is different between the past and new tokens
-        seq_min = torch.argmax(compare_past.double(), dim=1).min()
-
-        # Truncate the key-value cache up to the first position where past_ids and new_ids diverge.
-        # assumes shape is [num_layers, num_heads, seq_len, hidden_size]
-        self.past_key_values_kwargs["past_key_values"] = [
-            [kv[:, :, :seq_min, :] for kv in kvs]
-            for kvs in self.past_key_values_kwargs["past_key_values"]
-        ]
-
-    @staticmethod
-    def build_model_kwargs(model_name: str, **kwargs) -> tuple[str, dict]:
-        model_kwargs = {
-            "token": os.getenv("HF_TOKEN"),
-            # "attn_implementation": "flash_attention_2",
-            # "trust_remote_code": True,
-        }
-        quantize = kwargs.get("quantize", False)
-        if quantize == "int4":
-            model_kwargs["quantization_config"] = BitsAndBytesConfig(
-                load_in_4bit=True,
-            )
-        elif quantize == "int8":
-            model_kwargs["quantization_config"] = BitsAndBytesConfig(
-                load_in_8bit=True,
-            )
-        else:
-            model_kwargs["torch_dtype"] = torch.bfloat16
-
-        downloaded_models = get_downloaded_models()
-        if model_name in downloaded_models:
-            model_kwargs["local_files_only"] = True
-            model_name = downloaded_models[model_name]
-            logger.info(f"Using local model {model_name}")
-        if "/plancraft/outputs" in model_name:
-            model_kwargs["local_files_only"] = True
-            logger.info(f"Using local model {model_name}")
-
-        return model_name, model_kwargs
-
-    def reset(self):
-        # NOTE: a past_key_values cache with a rolling window
-        # is not maximally useful, as the beginning shifts over time
-        # and the cache is therefore invalidated
-        self.past_key_values_kwargs = {}
-        self.past_token_ids = None
-
-    def prepare_messages(
-        self,
-        history: History,
-        max_messages_window: int,
-        system_prompt: dict = None,
-        prompt_images: list = [],
-    ) -> tuple[list[dict], list]:
-        """
-        Prepare the messages using a history
-        """
-        message_window = history.dialogue_history[-max_messages_window:]
-        # remove the first assistant message if it is present
-        if len(message_window) > 0 and message_window[0]["role"] == "assistant":
-            message_window = message_window[1:]
-        # add the system prompt if the first message is not a system message
-        if message_window[0]["role"] != "system" and system_prompt is not None:
-            message_window = [system_prompt] + message_window
-
-        image_window = []
-        if self.use_images:
-            image_list = prompt_images + history.images
-            image_count = 0
-            # count the images referenced in the message window
-            for m in message_window:
-                for content in m["content"]:
-                    if content["type"] == "image":
-                        image_count += 1
-            assert image_count <= len(image_list), "Too many images"
-            image_window = image_list[-image_count:]
-            image_window = [Image.fromarray(img) for img in image_window]
-
-        return message_window, image_window
-
-    @torch.inference_mode()
-    def generate_unconstrained(
-        self,
-        batch_messages: list[list[dict]],
-        start_messages_generation: str = "",
-        max_tokens: int = 256,
-        temperature=0.6,
-        **kwargs,
-    ) -> tuple[list[str], int]:
-        """
-        Generate unconstrained text based on the batch of messages.
-        """
-        if self.use_images:
-            assert "images" in kwargs, "Images required for multimodal model"
-
-        tokenized_messages = tokenize(
-            self.model,
-            self.tokenizer,
-            batch_messages,
-            start_messages_generation=[start_messages_generation] * len(batch_messages),
-            max_tokens=max_tokens,
-            images=kwargs.get("images") if self.use_images else None,
-        )
-        prompt_tokens = tokenized_messages["input_ids"].shape[-1]
-
-        # Send to the same device as the model
-        tokenized_messages = {
-            k: v.to(self.model.device) for k, v in tokenized_messages.items()
-        }
-
-        # Truncate the key-value cache
-        self.truncate_kv_cache(tokenized_messages["input_ids"])
-
-        if (
-            "past_key_values" in self.past_key_values_kwargs
-            and self.past_key_values_kwargs["past_key_values"][0][0].shape[-2]
-            > tokenized_messages["input_ids"].shape[-1]
-        ):
-            raise ValueError("Past key values are larger than the input_ids")
-
-        past_key_values = self.past_key_values_kwargs.get("past_key_values", None)
-        if past_key_values is not None:
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-
-        generated_sequences = self.model.generate(
-            **tokenized_messages,
-            do_sample=True,
-            temperature=temperature,
-            max_new_tokens=max_tokens,
-            pad_token_id=self.pad_token_id,
-            return_dict_in_generate=True,
-            use_cache=True,
-            past_key_values=past_key_values,
-            return_legacy_cache=True,
-        )
-        # Cache the past key values
-        if self.use_hot_cache:
-            self.past_key_values_kwargs["past_key_values"] = (
-                generated_sequences.past_key_values
-            )
-            self.past_token_ids = generated_sequences.sequences
-
-        # Decode the output
-        text_responses = self.tokenizer.batch_decode(
-            generated_sequences.sequences[:, prompt_tokens:],
-            skip_special_tokens=False,
-        )
-
-        text_responses = [
-            text_response.replace("<|eot_id|>", "") for text_response in text_responses
-        ]
-
-        _, total_tokens_used = generated_sequences.sequences.shape
-        return text_responses, total_tokens_used
-
-
-class OpenAIGenerator:
-    def __init__(self, use_images=False, model_name="gpt-4o-mini"):
-        self.client = OpenAI()
-        self.use_images = use_images
-        self.model_name = model_name
-
-    def reset(self):
-        pass
-
-    def prepare_messages(
-        self,
-        history: History,
-        max_messages_window: int,
-        system_prompt: dict = None,
-        prompt_images: list = [],
-    ) -> tuple[list[dict], list]:
-        """
-        Prepare the image messages for the model
-        """
-        message_window = history.dialogue_history[-max_messages_window:]
-        # remove the first assistant message if it is present
-        if len(message_window) > 0 and message_window[0]["role"] == "assistant":
-            message_window = message_window[1:]
-        # add the system prompt if the first message is not a system message
-        if message_window[0]["role"] != "system" and system_prompt is not None:
-            message_window = [system_prompt] + message_window
-
-        if self.use_images:
-            image_list = prompt_images + history.images
-            img_idx = -1
-            seen_images = 0
-            # iterate through the messages in reverse order to assign images
-            for i in range(len(message_window) - 1, -1, -1):
-                new_content_list = []
-                for content in message_window[i]["content"]:
-                    if content["type"] == "text":
-                        new_content_list.append(content)
-                    elif content["type"] == "image":
-                        base64_image = numpy_to_base64(image_list[img_idx])
-                        img_idx -= 1
-                        seen_images += 1
-                        new_content = {
-                            "type": "image_url",
-                            "image_url": {
-                                "url": f"data:image/jpeg;base64,{base64_image}"
-                            },
-                        }
-                        new_content_list.append(new_content)
-                message_window[i]["content"] = new_content_list
-            assert seen_images <= len(image_list), "Too many images"
-
-        return message_window, []
-
-    def generate_unconstrained(
-        self,
-        batch_messages: list[list[dict]],
-        max_tokens=256,
-        **kwargs,
-    ) -> tuple[list[str], int]:
-        contents = []
-        tokens_used = 0
-        for messages in batch_messages:
-            response = self.client.chat.completions.create(
-                model=self.model_name,
-                messages=messages,
-                temperature=0.0,
-                max_tokens=max_tokens,
-                top_p=1,
-                frequency_penalty=0,
-                presence_penalty=0,
-                stop=["\n", "\n\n"],
-            )
-            content = response.choices[0].message.content
-            tokens_used += response.usage.total_tokens
-            contents.append(content)
-        return contents, tokens_used
-
-
-class OAMGenerator:
-    def __init__(
-        self,
-        model_name,
-    ):
-        self.model_name = model_name
-        logger.info("Loading model")
-        time_now = time.time()
-        self.model = PlancraftOAM.from_pretrained(model_name)
-        self.model.cuda()
-        logger.info(f"Model loaded in {time.time() - time_now:.2f} seconds")
-        # compile
-        time_now = time.time()
-        self.model = torch.compile(self.model)
-        logger.info(f"Model compiled in {time.time() - time_now:.2f} seconds")
-
-        self.model.eval()
-        self.model.tokenizer.pad_token = self.model.tokenizer.eos_token
-        self.model.tokenizer.truncation_side = "left"
-
-    def reset(self):
-        pass
-
-    def prepare_messages(
-        self,
-        history: History,
-        max_messages_window: int,
-        system_prompt: dict = None,
-        prompt_images: list = [],
-    ) -> tuple[list[dict], list]:
-        """
-        Prepare the messages using a history
-        """
-        message_window = history.dialogue_history[-max_messages_window:]
-        # remove the first assistant message if it is present
-        if len(message_window) > 0 and message_window[0]["role"] == "assistant":
-            message_window = message_window[1:]
-        # add the system prompt if the first message is not a system message
-        if message_window[0]["role"] != "system" and system_prompt is not None:
-            message_window = [system_prompt] + message_window
-
-        image_window = []
-
-        image_list = prompt_images + history.images
-        image_count = 0
-
-        # iterate through the messages to count how many images are present
-        new_message_window = []
-        for m in message_window:
-            message_content = m["content"]
-            message_role = m["role"]
-
-            if "\ninventory:\n" in message_content:
-                message_content = (
-                    message_content.split("\ninventory:\n")[0]
-                    + "\ninventory:<|inventory|>"
-                )
-
-            if "<|inventory|>" in message_content:
-                image_count += 1
-
-            new_message_window.append(
-                {"role": message_role, "content": message_content}
-            )
-
-        assert image_count <= len(image_list), "Too many images"
-        # get images from the end of the queue
-        image_window = image_list[-image_count:]
-        image_window = [Image.fromarray(img) for img in image_window]
-
-        return new_message_window, image_window
-
-    @torch.inference_mode()
-    def generate_unconstrained(
-        self,
-        batch_messages: list[list[dict]],
-        images: list,
-        max_tokens: int = 256,
-        temperature=0.6,
-        **kwargs,
-    ) -> tuple[list[str], int]:
-        """
-        Generate unconstrained text based on the batch of messages.
-        """
-        text_responses, total_tokens_used = self.model.generate(
-            batch_messages=batch_messages,
-            batch_images=images,
-            do_sample=True,
-            temperature=temperature,
-            max_new_tokens=max_tokens,
-        )
-        return text_responses, total_tokens_used
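
The deleted `truncate_kv_cache` keeps only the key/value entries for the longest prefix shared by the previously cached token ids and the new prompt. A minimal standalone sketch of that prefix computation, using plain `torch` and made-up token ids rather than a real cache:

```python
import torch

# Made-up token ids, shape [batch, seq_len]; not taken from the package.
past_ids = torch.tensor([[1, 2, 3, 4, 5]])
new_ids = torch.tensor([[1, 2, 9, 4, 5, 6]])

# Compare only up to the shorter of the two sequences.
min_len = min(past_ids.shape[1], new_ids.shape[1])
diff = past_ids[:, :min_len] != new_ids[:, :min_len]

if diff.any():
    # Earliest position (across the batch) where the sequences diverge.
    keep = torch.argmax(diff.double(), dim=1).min()
else:
    keep = min_len

# truncate_kv_cache then slices every cached key/value tensor along its
# sequence dimension to this length; here we only print the kept length.
print(int(keep))  # -> 2: only the shared prefix [1, 2] stays cached
```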
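
All three deleted classes expose the same small interface (`reset`, `prepare_messages`, `generate_unconstrained`). A rough usage sketch against the 0.1.0 wheel; the `SimpleNamespace` stand-in for `plancraft.models.base.History`, the window size, and the prompt texts are illustrative assumptions (this `prepare_messages` only reads `dialogue_history`, and `images` when `use_images=True`):

```python
from types import SimpleNamespace

from plancraft.models.generators import OpenAIGenerator  # removed in 0.1.2

# Stand-in for plancraft's History object (assumed minimal attributes).
history = SimpleNamespace(
    dialogue_history=[{"role": "user", "content": "Craft an item from the inventory."}],
    images=[],
)

generator = OpenAIGenerator(use_images=False, model_name="gpt-4o-mini")
generator.reset()

messages, _ = generator.prepare_messages(
    history,
    max_messages_window=30,  # assumed value
    system_prompt={"role": "system", "content": "You are a crafting assistant."},
)

# Requires OPENAI_API_KEY in the environment; one response per batch entry.
responses, tokens_used = generator.generate_unconstrained(batch_messages=[messages])
print(responses[0], tokens_used)
```

`TransformersGenerator` and `OAMGenerator` follow the same call pattern, with the extra image and KV-cache handling shown in the diff above.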