plancraft 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
models/generators.py ADDED
@@ -0,0 +1,483 @@
+ import logging
+ import os
+ import time
+
+ import torch
+ from dotenv import load_dotenv
+ from openai import OpenAI
+ from PIL import Image
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoModelForVision2Seq,
+     AutoProcessor,
+     AutoTokenizer,
+     BitsAndBytesConfig,
+ )
+ from transformers.cache_utils import DynamicCache
+
+ from plancraft.models.base import History
+ from plancraft.models.oam import PlancraftOAM
+ from plancraft.models.utils import (
+     get_downloaded_models,
+     numpy_to_base64,
+     tokenize,
+ )
+
+
+ logger = logging.getLogger(__name__)
+
+ load_dotenv()
+
+
+ class TransformersGenerator:
+     def __init__(
+         self,
+         model_name: str,
+         tokenizer_name: str = "same",
+         quantize=False,
+         use_images=False,
+         use_hot_cache=True,
+         adapter_name="",
+         **kwargs,
+     ):
+         self.model_name = model_name
+         self.use_hot_cache = use_hot_cache
+
+         if tokenizer_name == "same":
+             tokenizer_name = model_name
+
+         self.use_images = use_images
+         model_name, model_kwargs = self.build_model_kwargs(
+             model_name, quantize=quantize
+         )
+         self.processor = None
+         if "idefics" in model_name:
+             assert use_images, "Idefics model requires multimodal input"
+             self.tokenizer = AutoProcessor.from_pretrained(
+                 tokenizer_name,
+                 **model_kwargs,
+             )
+             self.tokenizer.eos_token_id = self.tokenizer.tokenizer.eos_token_id
+             logger.info("Loading model")
+             time_now = time.time()
+             self.model = AutoModelForVision2Seq.from_pretrained(
+                 model_name,
+                 device_map="auto",
+                 **model_kwargs,
+             )
+             logger.info(f"Model loaded in {time.time() - time_now:.2f} seconds")
+             # set pad_token_id
+             if self.tokenizer.tokenizer.pad_token_id:
+                 self.pad_token_id = self.tokenizer.tokenizer.pad_token_id
+             else:
+                 self.pad_token_id = self.tokenizer.tokenizer.eos_token_id
+         else:
+             self.tokenizer = AutoTokenizer.from_pretrained(
+                 tokenizer_name,
+                 token=os.getenv("HF_TOKEN"),  # trust_remote_code=True
+                 padding_side="left",  # ensure that the padding is on the left
+             )
+             logger.info("Loading model")
+             time_now = time.time()
+             self.model = AutoModelForCausalLM.from_pretrained(
+                 model_name,
+                 device_map="auto",
+                 **model_kwargs,
+             )
+             logger.info(f"Model loaded in {time.time() - time_now:.2f} seconds")
+
+             # load OA adapter
+             if adapter_name != "":
+                 logger.info(f"Loading adapter and tokenizer from {adapter_name}")
+                 self.tokenizer = AutoTokenizer.from_pretrained(
+                     adapter_name,
+                     padding_side="left",
+                 )
+                 self.model.resize_token_embeddings(len(self.tokenizer))
+                 self.model.load_adapter(adapter_name)
+
+             # set pad_token_id
+             if self.tokenizer.pad_token_id:
+                 self.pad_token_id = self.tokenizer.pad_token_id
+             else:
+                 self.tokenizer.pad_token = self.tokenizer.eos_token
+                 self.pad_token_id = self.tokenizer.eos_token_id
+
+         # compile
+         time_now = time.time()
+         self.model = torch.compile(self.model)
+         logger.info(f"Model compiled in {time.time() - time_now:.2f} seconds")
+
+         self.model.eval()
+         if self.pad_token_id is None:
+             self.tokenizer.pad_token = self.tokenizer.eos_token
+             self.model.config.pad_token_id = self.model.config.eos_token_id
+         self.tokenizer.truncation_side = "left"
+
+         self.past_key_values_kwargs = {}
+         self.past_token_ids = None
+
+     def truncate_kv_cache(self, new_token_ids: torch.Tensor):
+         """
+         Truncate the key-value cache to the length over which the past token ids
+         and the new token ids overlap.
+         Uses:
+             past_ids: torch.Tensor [B, T]
+             new_ids: torch.Tensor [B, T]
+             kv_cache: tuple[tuple[torch.Tensor]]: tuple of key-value cache tensors
+
+         NOTE: with batch_size==1 this amounts, in the worst case, to caching only
+         the shared prefix (e.g. the system prompt).
+         """
+         if (
+             self.past_token_ids is None
+             or "past_key_values" not in self.past_key_values_kwargs
+         ):
+             return
+
+         # caching doesn't seem to work with multimodal models
+         if self.use_images:
+             self.past_key_values_kwargs = {}
+             return
+
+         past_batch_size, past_seq_len = self.past_token_ids.shape
+         new_batch_size, new_seq_len = new_token_ids.shape
+
+         # If the batch size has changed, reset the cache
+         if past_batch_size != new_batch_size:
+             self.past_key_values_kwargs = {}
+             return
+
+         min_shape = min(past_seq_len, new_seq_len)
+         compare_past = (
+             self.past_token_ids[:, :min_shape] != new_token_ids[:, :min_shape]
+         )
+
+         # All tokens are the same - no need to truncate
+         if not compare_past.any():
+             return
+
+         # Find the first token that differs between the past and new tokens
+         seq_min = torch.argmax(compare_past.double(), dim=1).min()
+
+         # Truncate the key-value cache up to the first mismatching position
+         # (each cached tensor is assumed to be [batch, num_heads, seq_len, head_dim])
+         self.past_key_values_kwargs["past_key_values"] = [
+             [kv[:, :, :seq_min, :] for kv in kvs]
+             for kvs in self.past_key_values_kwargs["past_key_values"]
+         ]
+
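To make the truncation logic concrete, here is a minimal sketch (toy token ids and a made-up two-layer legacy cache, not part of the package) of how the first divergent position is found and how the cached keys/values are sliced back to the shared prefix:

import torch

past_ids = torch.tensor([[1, 2, 3, 4, 5]])
new_ids = torch.tensor([[1, 2, 3, 9, 9, 9]])

min_len = min(past_ids.shape[1], new_ids.shape[1])
diff = past_ids[:, :min_len] != new_ids[:, :min_len]
# first position where the sequences diverge (index 3 in this example)
seq_min = torch.argmax(diff.double(), dim=1).min()

# toy legacy cache: 2 layers x (key, value), each [batch, num_heads, seq_len, head_dim]
cache = [
    [torch.randn(1, 4, past_ids.shape[1], 8) for _ in range(2)] for _ in range(2)
]
truncated = [[kv[:, :, :seq_min, :] for kv in kvs] for kvs in cache]
assert truncated[0][0].shape[2] == seq_min  # only the shared prefix is kept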
+     @staticmethod
+     def build_model_kwargs(model_name: str, **kwargs) -> tuple[str, dict]:
+         model_kwargs = {
+             "token": os.getenv("HF_TOKEN"),
+             # "attn_implementation": "flash_attention_2",
+             # "trust_remote_code": True,
+         }
+         quantize = kwargs.get("quantize", False)
+         if quantize == "int4":
+             model_kwargs["quantization_config"] = BitsAndBytesConfig(
+                 load_in_4bit=True,
+             )
+         elif quantize == "int8":
+             model_kwargs["quantization_config"] = BitsAndBytesConfig(
+                 load_in_8bit=True,
+             )
+         else:
+             model_kwargs["torch_dtype"] = torch.bfloat16
+
+         downloaded_models = get_downloaded_models()
+         if model_name in downloaded_models:
+             model_kwargs["local_files_only"] = True
+             model_name = downloaded_models[model_name]
+             logger.info(f"Using local model {model_name}")
+         if "/plancraft/outputs" in model_name:
+             model_kwargs["local_files_only"] = True
+             logger.info(f"Using local model {model_name}")
+
+         return model_name, model_kwargs
+
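For reference, a quick sketch of how the quantization branch above resolves (the model name here is a hypothetical example, and HF_TOKEN is assumed to be set in the environment):

name, kwargs = TransformersGenerator.build_model_kwargs(
    "meta-llama/Meta-Llama-3-8B-Instruct", quantize="int4"
)
# kwargs now carries BitsAndBytesConfig(load_in_4bit=True) under "quantization_config";
# with quantize=False it would instead set torch_dtype=torch.bfloat16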
+     def reset(self):
+         # NOTE: a past_key_values cache kept over a rolling message window is of
+         # limited use: as the window slides, the beginning of the prompt changes
+         # and the cached prefix is invalidated
+         self.past_key_values_kwargs = {}
+         self.past_token_ids = None
+
+     def prepare_messages(
+         self,
+         history: History,
+         max_messages_window: int,
+         system_prompt: dict = None,
+         prompt_images: list = [],
+     ) -> tuple[list[dict], list]:
+         """
+         Prepare the message window (and matching images) from a dialogue history.
+         """
+         message_window = history.dialogue_history[-max_messages_window:]
+         # remove the first assistant message if it is present
+         if len(message_window) > 0 and message_window[0]["role"] == "assistant":
+             message_window = message_window[1:]
+         # add the system prompt if the first message is not a system message
+         if message_window[0]["role"] != "system" and system_prompt is not None:
+             message_window = [system_prompt] + message_window
+
+         image_window = []
+         if self.use_images:
+             image_list = prompt_images + history.images
+             image_count = 0
+             # count how many image placeholders appear in the message window
+             for m in message_window:
+                 for content in m["content"]:
+                     if content["type"] == "image":
+                         image_count += 1
+             assert image_count <= len(image_list), "Too many images"
+             # keep the most recent images, one per placeholder
+             image_window = image_list[-image_count:]
+             image_window = [Image.fromarray(img) for img in image_window]
+
+         return message_window, image_window
+
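As a concrete (hypothetical) illustration of the expected inputs: the dialogue history holds OpenAI-style message dicts, and multimodal messages carry a list of typed content blocks, so a window returned by the method above would look roughly like this:

system_prompt = {"role": "system", "content": "You are a crafting assistant."}
window = [
    system_prompt,
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Craft an item of type: andesite"},
            {"type": "image"},  # placeholder resolved against history.images
        ],
    },
    {"role": "assistant", "content": "act: move from slot 27 to slot 4 with quantity 1"},
]
# each {"type": "image"} placeholder is paired with the most recent entries of history.images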
+     @torch.inference_mode()
+     def generate_unconstrained(
+         self,
+         batch_messages: list[list[dict]],
+         start_messages_generation: str = "",
+         max_tokens: int = 256,
+         temperature=0.6,
+         **kwargs,
+     ) -> tuple[list[str], int]:
+         """
+         Generate unconstrained text based on the batch of messages.
+         """
+         if self.use_images:
+             assert "images" in kwargs, "Images required for multimodal model"
+
+         tokenized_messages = tokenize(
+             self.model,
+             self.tokenizer,
+             batch_messages,
+             start_messages_generation=[start_messages_generation] * len(batch_messages),
+             max_tokens=max_tokens,
+             images=kwargs.get("images") if self.use_images else None,
+         )
+         prompt_tokens = tokenized_messages["input_ids"].shape[-1]
+
+         # Send the inputs to the same device as the model
+         tokenized_messages = {
+             k: v.to(self.model.device) for k, v in tokenized_messages.items()
+         }
+
+         # Truncate the key-value cache
+         self.truncate_kv_cache(tokenized_messages["input_ids"])
+
+         if (
+             "past_key_values" in self.past_key_values_kwargs
+             and self.past_key_values_kwargs["past_key_values"][0][0].shape[-2]
+             > tokenized_messages["input_ids"].shape[-1]
+         ):
+             raise ValueError("Past key values are larger than the input_ids")
+
+         past_key_values = self.past_key_values_kwargs.get("past_key_values", None)
+         if past_key_values is not None:
+             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+
+         generated_sequences = self.model.generate(
+             **tokenized_messages,
+             do_sample=True,
+             temperature=temperature,
+             max_new_tokens=max_tokens,
+             pad_token_id=self.pad_token_id,
+             return_dict_in_generate=True,
+             use_cache=True,
+             past_key_values=past_key_values,
+             return_legacy_cache=True,
+         )
+         # Cache the past key values
+         if self.use_hot_cache:
+             self.past_key_values_kwargs["past_key_values"] = (
+                 generated_sequences.past_key_values
+             )
+             self.past_token_ids = generated_sequences.sequences
+
+         # Decode the output
+         text_responses = self.tokenizer.batch_decode(
+             generated_sequences.sequences[:, prompt_tokens:],
+             skip_special_tokens=False,
+         )
+
+         text_responses = [
+             text_response.replace("<|eot_id|>", "") for text_response in text_responses
+         ]
+
+         _, total_tokens_used = generated_sequences.sequences.shape
+         return text_responses, total_tokens_used
+
+
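A rough usage sketch of how the two methods above are typically chained (the model name, window size, and system prompt are hypothetical; `history` is assumed to be a plancraft History built elsewhere):

generator = TransformersGenerator(
    model_name="meta-llama/Meta-Llama-3-8B-Instruct",  # hypothetical choice
    quantize="int4",
)
messages, _ = generator.prepare_messages(
    history=history,  # a plancraft History with dialogue_history / images
    max_messages_window=30,
    system_prompt={"role": "system", "content": "You are a crafting assistant."},
)
responses, tokens_used = generator.generate_unconstrained(
    batch_messages=[messages],
    max_tokens=256,
)
generator.reset()  # drop the hot KV cache between episodes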
+ class OpenAIGenerator:
+     def __init__(self, use_images=False, model_name="gpt-4o-mini"):
+         self.client = OpenAI()
+         self.use_images = use_images
+         self.model_name = model_name
+
+     def reset(self):
+         pass
+
+     def prepare_messages(
+         self,
+         history: History,
+         max_messages_window: int,
+         system_prompt: dict = None,
+         prompt_images: list = [],
+     ) -> tuple[list[dict], list]:
+         """
+         Prepare the messages for the OpenAI API, inlining any images as base64 URLs.
+         """
+         message_window = history.dialogue_history[-max_messages_window:]
+         # remove the first assistant message if it is present
+         if len(message_window) > 0 and message_window[0]["role"] == "assistant":
+             message_window = message_window[1:]
+         # add the system prompt if the first message is not a system message
+         if message_window[0]["role"] != "system" and system_prompt is not None:
+             message_window = [system_prompt] + message_window
+
+         if self.use_images:
+             image_list = prompt_images + history.images
+             img_idx = -1
+             seen_images = 0
+             # iterate through the messages in reverse order to assign images
+             for i in range(len(message_window) - 1, -1, -1):
+                 new_content_list = []
+                 for content in message_window[i]["content"]:
+                     if content["type"] == "text":
+                         new_content_list.append(content)
+                     elif content["type"] == "image":
+                         base64_image = numpy_to_base64(image_list[img_idx])
+                         img_idx -= 1
+                         seen_images += 1
+                         new_content = {
+                             "type": "image_url",
+                             "image_url": {
+                                 "url": f"data:image/jpeg;base64,{base64_image}"
+                             },
+                         }
+                         new_content_list.append(new_content)
+                 message_window[i]["content"] = new_content_list
+             assert seen_images <= len(image_list), "Too many images"
+
+         return message_window, []
+
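For reference, each {"type": "image"} content block is rewritten into the OpenAI image_url form, so a converted user message ends up looking roughly like this (hypothetical text, base64 payload elided):

converted_message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "Craft an item of type: andesite"},
        {
            "type": "image_url",
            "image_url": {"url": "data:image/jpeg;base64,..."},
        },
    ],
}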
+     def generate_unconstrained(
+         self,
+         batch_messages: list[list[dict]],
+         max_tokens=256,
+         **kwargs,
+     ) -> tuple[list[str], int]:
+         contents = []
+         tokens_used = 0
+         for messages in batch_messages:
+             response = self.client.chat.completions.create(
+                 model=self.model_name,
+                 messages=messages,
+                 temperature=0.0,
+                 max_tokens=max_tokens,
+                 top_p=1,
+                 frequency_penalty=0,
+                 presence_penalty=0,
+                 stop=["\n", "\n\n"],
+             )
+             content = response.choices[0].message.content
+             tokens_used += response.usage.total_tokens
+             contents.append(content)
+         return contents, tokens_used
+
+
+ class OAMGenerator:
+     def __init__(
+         self,
+         model_name,
+     ):
+         self.model_name = model_name
+         logger.info("Loading model")
+         time_now = time.time()
+         self.model = PlancraftOAM.from_pretrained(model_name)
+         self.model.cuda()
+         logger.info(f"Model loaded in {time.time() - time_now:.2f} seconds")
+         # compile
+         time_now = time.time()
+         self.model = torch.compile(self.model)
+         logger.info(f"Model compiled in {time.time() - time_now:.2f} seconds")
+
+         self.model.eval()
+         self.model.tokenizer.pad_token = self.model.tokenizer.eos_token
+         self.model.tokenizer.truncation_side = "left"
+
+     def reset(self):
+         pass
+
+     def prepare_messages(
+         self,
+         history: History,
+         max_messages_window: int,
+         system_prompt: dict = None,
+         prompt_images: list = [],
+     ) -> tuple[list[dict], list]:
+         """
+         Prepare the message window (and matching images) from a dialogue history.
+         """
+         message_window = history.dialogue_history[-max_messages_window:]
+         # remove the first assistant message if it is present
+         if len(message_window) > 0 and message_window[0]["role"] == "assistant":
+             message_window = message_window[1:]
+         # add the system prompt if the first message is not a system message
+         if message_window[0]["role"] != "system" and system_prompt is not None:
+             message_window = [system_prompt] + message_window
+
+         image_window = []
+
+         image_list = prompt_images + history.images
+         image_count = 0
+
+         # rewrite the messages and count how many inventory images are present
+         new_message_window = []
+         for m in message_window:
+             message_content = m["content"]
+             message_role = m["role"]
+
+             if "\ninventory:\n" in message_content:
+                 message_content = (
+                     message_content.split("\ninventory:\n")[0]
+                     + "\ninventory:<|inventory|>"
+                 )
+
+             if "<|inventory|>" in message_content:
+                 image_count += 1
+
+             new_message_window.append(
+                 {"role": message_role, "content": message_content}
+             )
+
+         assert image_count <= len(image_list), "Too many images"
+         # keep the most recent images, one per <|inventory|> placeholder
+         image_window = image_list[-image_count:]
+         image_window = [Image.fromarray(img) for img in image_window]
+
+         return new_message_window, image_window
+
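A small (hypothetical) before/after of the inventory rewriting above: the textual inventory dump is stripped and replaced by the <|inventory|> token, which PlancraftOAM later resolves against the matching image:

before = {
    "role": "user",
    "content": "Craft an item of type: andesite\ninventory:\n - diorite in slot 27\n - cobblestone in slot 39",
}
# after prepare_messages the same message becomes:
after = {
    "role": "user",
    "content": "Craft an item of type: andesite\ninventory:<|inventory|>",
}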
+     @torch.inference_mode()
+     def generate_unconstrained(
+         self,
+         batch_messages: list[list[dict]],
+         images: list,
+         max_tokens: int = 256,
+         temperature=0.6,
+         **kwargs,
+     ) -> tuple[list[str], int]:
+         """
+         Generate unconstrained text based on the batch of messages.
+         """
+         text_responses, total_tokens_used = self.model.generate(
+             batch_messages=batch_messages,
+             batch_images=images,
+             do_sample=True,
+             temperature=temperature,
+             max_new_tokens=max_tokens,
+         )
+         return text_responses, total_tokens_used
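Taken together, the three classes expose the same reset / prepare_messages / generate_unconstrained surface, so a caller can pick a backend by name. A minimal dispatch sketch (this helper is hypothetical; plancraft wires up its generators elsewhere):

def make_generator(backend: str, model_name: str, use_images: bool = False):
    # hypothetical convenience helper, not part of the package
    if backend == "openai":
        return OpenAIGenerator(use_images=use_images, model_name=model_name)
    if backend == "oam":
        return OAMGenerator(model_name=model_name)
    return TransformersGenerator(model_name=model_name, use_images=use_images)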