plancraft-0.1.0-py3-none-any.whl

models/generators.py ADDED
@@ -0,0 +1,483 @@
+import logging
+import os
+import time
+
+import torch
+from dotenv import load_dotenv
+from openai import OpenAI
+from PIL import Image
+from transformers import (
+    AutoModelForCausalLM,
+    AutoModelForVision2Seq,
+    AutoProcessor,
+    AutoTokenizer,
+    BitsAndBytesConfig,
+)
+from transformers.cache_utils import DynamicCache
+
+from plancraft.models.base import History
+from plancraft.models.oam import PlancraftOAM
+from plancraft.models.utils import (
+    get_downloaded_models,
+    numpy_to_base64,
+    tokenize,
+)
+
+
+logger = logging.getLogger(__name__)
+
+load_dotenv()
+
+
+class TransformersGenerator:
+    def __init__(
+        self,
+        model_name: str,
+        tokenizer_name: str = "same",
+        quantize=False,
+        use_images=False,
+        use_hot_cache=True,
+        adapter_name="",
+        **kwargs,
+    ):
+        self.model_name = model_name
+        self.use_hot_cache = use_hot_cache
+
+        if tokenizer_name == "same":
+            tokenizer_name = model_name
+
+        self.use_images = use_images
+        model_name, model_kwargs = self.build_model_kwargs(
+            model_name, quantize=quantize
+        )
+        self.processor = None
+        if "idefics" in model_name:
+            assert use_images, "Idefics model requires multimodal input"
+            self.tokenizer = AutoProcessor.from_pretrained(
+                tokenizer_name,
+                **model_kwargs,
+            )
+            self.tokenizer.eos_token_id = self.tokenizer.tokenizer.eos_token_id
+            logger.info("Loading model")
+            time_now = time.time()
+            self.model = AutoModelForVision2Seq.from_pretrained(
+                model_name,
+                device_map="auto",
+                **model_kwargs,
+            )
+            logger.info(f"Model loaded in {time.time() - time_now:.2f} seconds")
+            # set pad_token_id
+            if self.tokenizer.tokenizer.pad_token_id:
+                self.pad_token_id = self.tokenizer.tokenizer.pad_token_id
+            else:
+                self.pad_token_id = self.tokenizer.tokenizer.eos_token_id
+        else:
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                tokenizer_name,
+                token=os.getenv("HF_TOKEN"),  # trust_remote_code=True
+                padding_side="left",  # ensure that the padding is on the left
+            )
+            logger.info("Loading model")
+            time_now = time.time()
+            self.model = AutoModelForCausalLM.from_pretrained(
+                model_name,
+                device_map="auto",
+                **model_kwargs,
+            )
+            logger.info(f"Model loaded in {time.time() - time_now:.2f} seconds")
+
+            # load OA adapter
+            if adapter_name != "":
+                logger.info(f"Loading adapter and tokenizer from {adapter_name}")
+                self.tokenizer = AutoTokenizer.from_pretrained(
+                    adapter_name,
+                    padding_side="left",
+                )
+                self.model.resize_token_embeddings(len(self.tokenizer))
+                self.model.load_adapter(adapter_name)
+
+            # set pad_token_id
+            if self.tokenizer.pad_token_id:
+                self.pad_token_id = self.tokenizer.pad_token_id
+            else:
+                self.tokenizer.pad_token = self.tokenizer.eos_token
+                self.pad_token_id = self.tokenizer.eos_token_id
+
+        # compile
+        time_now = time.time()
+        self.model = torch.compile(self.model)
+        logger.info(f"Model compiled in {time.time() - time_now:.2f} seconds")
+
+        self.model.eval()
+        if self.pad_token_id is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+            self.model.config.pad_token_id = self.model.config.eos_token_id
+        self.tokenizer.truncation_side = "left"
+
+        self.past_key_values_kwargs = {}
+        self.past_token_ids = None
+
+    def truncate_kv_cache(self, new_token_ids: torch.Tensor):
+        """
+        Truncate the key-value cache to the prefix where the past_ids and the new_ids overlap.
+        Uses:
+            past_ids: torch.Tensor [B, T]
+            new_ids: torch.Tensor [B, T]
+            kv_cache: tuple[tuple[torch.Tensor]]: tuple of key-value cache tensors
+
+        NOTE: this essentially implements system prompt caching in the worst case when using batch_size==1
+        """
+        if (
+            self.past_token_ids is None
+            or "past_key_values" not in self.past_key_values_kwargs
+        ):
+            return
+
+        # caching doesn't seem to work with multimodal models
+        if self.use_images:
+            self.past_key_values_kwargs = {}
+            return
+
+        past_batch_size, past_seq_len = self.past_token_ids.shape
+        new_batch_size, new_seq_len = new_token_ids.shape
+
+        # If the batch size has changed, reset the cache
+        if past_batch_size != new_batch_size:
+            self.past_key_values_kwargs = {}
+            return
+
+        min_shape = min(past_seq_len, new_seq_len)
+        compare_past = (
+            self.past_token_ids[:, :min_shape] != new_token_ids[:, :min_shape]
+        )
+
+        # All tokens are the same - no need to truncate
+        if not compare_past.any():
+            return
+
+        # Find the first token that is different between the past and new tokens
+        seq_min = torch.argmax(compare_past.double(), dim=1).min()
+
+        # Truncate the key-value cache up to the first position where past and new ids diverge.
+        # assumes shape is [num_layers, num_heads, seq_len, hidden_size]
+        self.past_key_values_kwargs["past_key_values"] = [
+            [kv[:, :, :seq_min, :] for kv in kvs]
+            for kvs in self.past_key_values_kwargs["past_key_values"]
+        ]
+
+    @staticmethod
+    def build_model_kwargs(model_name: str, **kwargs) -> tuple[str, dict]:
+        model_kwargs = {
+            "token": os.getenv("HF_TOKEN"),
+            # "attn_implementation": "flash_attention_2",
+            # "trust_remote_code": True,
+        }
+        quantize = kwargs.get("quantize", False)
+        if quantize == "int4":
+            model_kwargs["quantization_config"] = BitsAndBytesConfig(
+                load_in_4bit=True,
+            )
+        elif quantize == "int8":
+            model_kwargs["quantization_config"] = BitsAndBytesConfig(
+                load_in_8bit=True,
+            )
+        else:
+            model_kwargs["torch_dtype"] = torch.bfloat16
+
+        downloaded_models = get_downloaded_models()
+        if model_name in downloaded_models:
+            model_kwargs["local_files_only"] = True
+            model_name = downloaded_models[model_name]
+            logger.info(f"Using local model {model_name}")
+        if "/plancraft/outputs" in model_name:
+            model_kwargs["local_files_only"] = True
+            logger.info(f"Using local model {model_name}")
+
+        return model_name, model_kwargs
+
+    def reset(self):
+        # NOTE: a past_key_values cache with a rolling window
+        # is not maximally useful as the beginning shifts over time
+        # and therefore the cache is invalidated
+        self.past_key_values_kwargs = {}
+        self.past_token_ids = None
+
+    def prepare_messages(
+        self,
+        history: History,
+        max_messages_window: int,
+        system_prompt: dict = None,
+        prompt_images: list = [],
+    ) -> tuple[list[dict], list]:
+        """
+        Prepare the messages using a history
+        """
+        message_window = history.dialogue_history[-max_messages_window:]
+        # remove the first assistant message if it is present
+        if len(message_window) > 0 and message_window[0]["role"] == "assistant":
+            message_window = message_window[1:]
+        # add the system prompt if the first message is not a system message
+        if message_window[0]["role"] != "system" and system_prompt is not None:
+            message_window = [system_prompt] + message_window
+
+        image_window = []
+        if self.use_images:
+            image_list = prompt_images + history.images
+            image_count = 0
+            # iterate through the messages to count how many images are present
+            for m in message_window:
+                for content in m["content"]:
+                    if content["type"] == "image":
+                        image_count += 1
+            assert image_count <= len(image_list), "Too many images"
+            image_window = image_list[-image_count:]
+            image_window = [Image.fromarray(img) for img in image_window]
+
+        return message_window, image_window
+
+    @torch.inference_mode()
+    def generate_unconstrained(
+        self,
+        batch_messages: list[list[dict]],
+        start_messages_generation: str = "",
+        max_tokens: int = 256,
+        temperature=0.6,
+        **kwargs,
+    ) -> tuple[list[str], int]:
+        """
+        Generate unconstrained text based on the batch of messages.
+        """
+        if self.use_images:
+            assert "images" in kwargs, "Images required for multimodal model"
+
+        tokenized_messages = tokenize(
+            self.model,
+            self.tokenizer,
+            batch_messages,
+            start_messages_generation=[start_messages_generation] * len(batch_messages),
+            max_tokens=max_tokens,
+            images=kwargs.get("images") if self.use_images else None,
+        )
+        prompt_tokens = tokenized_messages["input_ids"].shape[-1]
+
+        # Send the inputs to the same device as the model
+        tokenized_messages = {
+            k: v.to(self.model.device) for k, v in tokenized_messages.items()
+        }
+
+        # Truncate the key-value cache
+        self.truncate_kv_cache(tokenized_messages["input_ids"])
+
+        if (
+            "past_key_values" in self.past_key_values_kwargs
+            and self.past_key_values_kwargs["past_key_values"][0][0].shape[-2]
+            > tokenized_messages["input_ids"].shape[-1]
+        ):
+            raise ValueError("Past key values are larger than the input_ids")
+
+        past_key_values = self.past_key_values_kwargs.get("past_key_values", None)
+        if past_key_values is not None:
+            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+
+        generated_sequences = self.model.generate(
+            **tokenized_messages,
+            do_sample=True,
+            temperature=temperature,
+            max_new_tokens=max_tokens,
+            pad_token_id=self.pad_token_id,
+            return_dict_in_generate=True,
+            use_cache=True,
+            past_key_values=past_key_values,
+            return_legacy_cache=True,
+        )
+        # Cache the past key values
+        if self.use_hot_cache:
+            self.past_key_values_kwargs["past_key_values"] = (
+                generated_sequences.past_key_values
+            )
+            self.past_token_ids = generated_sequences.sequences
+
+        # Decode the output
+        text_responses = self.tokenizer.batch_decode(
+            generated_sequences.sequences[:, prompt_tokens:],
+            skip_special_tokens=False,
+        )
+
+        text_responses = [
+            text_response.replace("<|eot_id|>", "") for text_response in text_responses
+        ]
+
+        _, total_tokens_used = generated_sequences.sequences.shape
+        return text_responses, total_tokens_used
+
+
+class OpenAIGenerator:
+    def __init__(self, use_images=False, model_name="gpt-4o-mini"):
+        self.client = OpenAI()
+        self.use_images = use_images
+        self.model_name = model_name
+
+    def reset(self):
+        pass
+
+    def prepare_messages(
+        self,
+        history: History,
+        max_messages_window: int,
+        system_prompt: dict = None,
+        prompt_images: list = [],
+    ) -> tuple[list[dict], list]:
+        """
+        Prepare the image messages for the model
+        """
+        message_window = history.dialogue_history[-max_messages_window:]
+        # remove the first assistant message if it is present
+        if len(message_window) > 0 and message_window[0]["role"] == "assistant":
+            message_window = message_window[1:]
+        # add the system prompt if the first message is not a system message
+        if message_window[0]["role"] != "system" and system_prompt is not None:
+            message_window = [system_prompt] + message_window
+
+        if self.use_images:
+            image_list = prompt_images + history.images
+            img_idx = -1
+            seen_images = 0
+            # iterate through the messages in reverse order to assign images
+            for i in range(len(message_window) - 1, -1, -1):
+                new_content_list = []
+                for content in message_window[i]["content"]:
+                    if content["type"] == "text":
+                        new_content_list.append(content)
+                    elif content["type"] == "image":
+                        base64_image = numpy_to_base64(image_list[img_idx])
+                        img_idx -= 1
+                        seen_images += 1
+                        new_content = {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:image/jpeg;base64,{base64_image}"
+                            },
+                        }
+                        new_content_list.append(new_content)
+                message_window[i]["content"] = new_content_list
+            assert seen_images <= len(image_list), "Too many images"
+
+        return message_window, []
+
+    def generate_unconstrained(
+        self,
+        batch_messages: list[list[dict]],
+        max_tokens=256,
+        **kwargs,
+    ) -> tuple[list[str], int]:
+        contents = []
+        tokens_used = 0
+        for messages in batch_messages:
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=messages,
+                temperature=0.0,
+                max_tokens=max_tokens,
+                top_p=1,
+                frequency_penalty=0,
+                presence_penalty=0,
+                stop=["\n", "\n\n"],
+            )
+            content = response.choices[0].message.content
+            tokens_used += response.usage.total_tokens
+            contents.append(content)
+        return contents, tokens_used
+
+
+class OAMGenerator:
+    def __init__(
+        self,
+        model_name,
+    ):
+        self.model_name = model_name
+        logger.info("Loading model")
+        time_now = time.time()
+        self.model = PlancraftOAM.from_pretrained(model_name)
+        self.model.cuda()
+        logger.info(f"Model loaded in {time.time() - time_now:.2f} seconds")
+        # compile
+        time_now = time.time()
+        self.model = torch.compile(self.model)
+        logger.info(f"Model compiled in {time.time() - time_now:.2f} seconds")
+
+        self.model.eval()
+        self.model.tokenizer.pad_token = self.model.tokenizer.eos_token
+        self.model.tokenizer.truncation_side = "left"
+
+    def reset(self):
+        pass
+
+    def prepare_messages(
+        self,
+        history: History,
+        max_messages_window: int,
+        system_prompt: dict = None,
+        prompt_images: list = [],
+    ) -> tuple[list[dict], list]:
+        """
+        Prepare the messages using a history
+        """
+        message_window = history.dialogue_history[-max_messages_window:]
+        # remove the first assistant message if it is present
+        if len(message_window) > 0 and message_window[0]["role"] == "assistant":
+            message_window = message_window[1:]
+        # add the system prompt if the first message is not a system message
+        if message_window[0]["role"] != "system" and system_prompt is not None:
+            message_window = [system_prompt] + message_window
+
+        image_window = []
+
+        image_list = prompt_images + history.images
+        image_count = 0
+
+        # iterate through the messages to count how many images are present
+        new_message_window = []
+        for m in message_window:
+            message_content = m["content"]
+            message_role = m["role"]
+
+            if "\ninventory:\n" in message_content:
+                message_content = (
+                    message_content.split("\ninventory:\n")[0]
+                    + "\ninventory:<|inventory|>"
+                )
+
+            if "<|inventory|>" in message_content:
+                image_count += 1
+
+            new_message_window.append(
+                {"role": message_role, "content": message_content}
+            )
+
+        assert image_count <= len(image_list), "Too many images"
+        # take the most recent images from the end of the list
+        image_window = image_list[-image_count:]
+        image_window = [Image.fromarray(img) for img in image_window]
+
+        return new_message_window, image_window
+
+    @torch.inference_mode()
+    def generate_unconstrained(
+        self,
+        batch_messages: list[list[dict]],
+        images: list,
+        max_tokens: int = 256,
+        temperature=0.6,
+        **kwargs,
+    ) -> tuple[list[str], int]:
+        """
+        Generate unconstrained text based on the batch of messages.
+        """
+        text_responses, total_tokens_used = self.model.generate(
+            batch_messages=batch_messages,
+            batch_images=images,
+            do_sample=True,
+            temperature=temperature,
+            max_new_tokens=max_tokens,
+        )
+        return text_responses, total_tokens_used
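
The truncate_kv_cache method above reuses the hot key-value cache across turns by keeping only the prefix on which the previous and current prompts agree. The following standalone sketch is illustrative only and not part of the package; the toy token ids and cache shapes are assumptions, but the slice-point computation mirrors the method:

import torch

# Toy "past" and "new" prompts for a batch of 1: they share the first 4 tokens.
past_ids = torch.tensor([[101, 7, 8, 9, 55, 56]])
new_ids = torch.tensor([[101, 7, 8, 9, 77, 78, 79]])

min_len = min(past_ids.shape[1], new_ids.shape[1])
mismatch = past_ids[:, :min_len] != new_ids[:, :min_len]

# First position where the prompts diverge (4 in this example)
seq_min = torch.argmax(mismatch.double(), dim=1).min()

# A fake legacy cache: one (key, value) pair per layer, each [batch, heads, seq_len, head_dim]
legacy_cache = [
    (torch.randn(1, 2, past_ids.shape[1], 8), torch.randn(1, 2, past_ids.shape[1], 8))
    for _ in range(3)
]

# Keep only the shared prefix, exactly as truncate_kv_cache does
truncated = [
    [kv[:, :, :seq_min, :] for kv in kvs]
    for kvs in legacy_cache
]
print(seq_min.item(), truncated[0][0].shape)  # 4, torch.Size([1, 2, 4, 8])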
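
For context on how these generator classes are driven: prepare_messages trims the dialogue history to a message window (re-attaching the system prompt), and generate_unconstrained returns one completion per message list plus a token count. A minimal text-only sketch using OpenAIGenerator, assuming OPENAI_API_KEY is set and that the module is importable as plancraft.models.generators; the prompt content below is made up for illustration and is not taken from plancraft's real prompts:

from plancraft.models.generators import OpenAIGenerator

generator = OpenAIGenerator(use_images=False, model_name="gpt-4o-mini")

# One conversation per batch entry, in plain OpenAI chat format
messages = [
    {"role": "system", "content": "You are an assistant that picks crafting actions."},
    {"role": "user", "content": "Craft an iron pickaxe. What is your next action?"},
]

responses, tokens_used = generator.generate_unconstrained(
    batch_messages=[messages],
    max_tokens=64,
)
print(responses[0], tokens_used)

Note that OpenAIGenerator passes stop=["\n", "\n\n"], so each response is truncated to a single line, which suits one-action-per-turn agents.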