plancraft 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl

@@ -0,0 +1,16 @@
+ import os
+ import glob
+
+ import numpy as np
+ import imageio
+
+
+ def get_few_shot_images_path():
+     return os.path.dirname(__file__)
+
+
+ def load_prompt_images() -> list[np.ndarray]:
+     current_dir = get_few_shot_images_path()
+     files = glob.glob(os.path.join(current_dir, "*.png"))
+     images = [imageio.imread(file) for file in files]
+     return images
@@ -0,0 +1,480 @@
+ import os
+ import time
+
+ import torch
+ from dotenv import load_dotenv
+ from loguru import logger
+ from openai import OpenAI
+ from PIL import Image
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoModelForVision2Seq,
+     AutoProcessor,
+     AutoTokenizer,
+     BitsAndBytesConfig,
+ )
+ from transformers.cache_utils import DynamicCache
+
+ from plancraft.models.base import History
+ from plancraft.models.oam import PlancraftOAM
+ from plancraft.models.utils import (
+     get_downloaded_models,
+     numpy_to_base64,
+     tokenize,
+ )
+
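+ # pull credentials from a local .env file (HF_TOKEN is read below; the OpenAI client picks up OPENAI_API_KEY)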
+ load_dotenv()
+
+
+ class TransformersGenerator:
+     def __init__(
+         self,
+         model_name: str,
+         tokenizer_name: str = "same",
+         quantize=False,
+         use_images=False,
+         use_hot_cache=True,
+         adapter_name="",
+         **kwargs,
+     ):
+         self.model_name = model_name
+         self.use_hot_cache = use_hot_cache
+
+         if tokenizer_name == "same":
+             tokenizer_name = model_name
+
+         self.use_images = use_images
+         model_name, model_kwargs = self.build_model_kwargs(
+             model_name, quantize=quantize
+         )
+         self.processor = None
+         if "idefics" in model_name:
+             assert use_images, "Idefics model requires multimodal input"
+             self.tokenizer = AutoProcessor.from_pretrained(
+                 tokenizer_name,
+                 **model_kwargs,
+             )
+             self.tokenizer.eos_token_id = self.tokenizer.tokenizer.eos_token_id
+             logger.info("Loading model")
+             time_now = time.time()
+             self.model = AutoModelForVision2Seq.from_pretrained(
+                 model_name,
+                 device_map="auto",
+                 **model_kwargs,
+             )
+             logger.info(f"Model loaded in {time.time() - time_now:.2f} seconds")
+             # set pad_token_id
+             if self.tokenizer.tokenizer.pad_token_id:
+                 self.pad_token_id = self.tokenizer.tokenizer.pad_token_id
+             else:
+                 self.pad_token_id = self.tokenizer.tokenizer.eos_token_id
+         else:
+             self.tokenizer = AutoTokenizer.from_pretrained(
+                 tokenizer_name,
+                 token=os.getenv("HF_TOKEN"),  # trust_remote_code=True
+                 padding_side="left",  # ensure that the padding is on the left
+             )
+             logger.info("Loading model")
+             time_now = time.time()
+             self.model = AutoModelForCausalLM.from_pretrained(
+                 model_name,
+                 device_map="auto",
+                 **model_kwargs,
+             )
+             logger.info(f"Model loaded in {time.time() - time_now:.2f} seconds")
+
+             # load OA adapter
+             if adapter_name != "":
+                 logger.info(f"Loading adapter and tokenizer from {adapter_name}")
+                 self.tokenizer = AutoTokenizer.from_pretrained(
+                     adapter_name,
+                     padding_side="left",
+                 )
+                 self.model.resize_token_embeddings(len(self.tokenizer))
+                 self.model.load_adapter(adapter_name)
+
+             # set pad_token_id
+             if self.tokenizer.pad_token_id:
+                 self.pad_token_id = self.tokenizer.pad_token_id
+             else:
+                 self.tokenizer.pad_token = self.tokenizer.eos_token
+                 self.pad_token_id = self.tokenizer.eos_token_id
+
+         # compile
+         time_now = time.time()
+         self.model = torch.compile(self.model)
+         logger.info(f"Model compiled in {time.time() - time_now:.2f} seconds")
+
+         self.model.eval()
+         if self.pad_token_id is None:
+             self.tokenizer.pad_token = self.tokenizer.eos_token
+             self.model.config.pad_token_id = self.model.config.eos_token_id
+         self.tokenizer.truncation_side = "left"
+
+         self.past_key_values_kwargs = {}
+         self.past_token_ids = None
+
+     def truncate_kv_cache(self, new_token_ids: torch.Tensor):
+         """
+         Truncate the key-value cache to the length of the prefix shared by the past_ids and the new_ids.
+         Uses:
+             past_ids: torch.Tensor [B, T]
+             new_ids: torch.Tensor [B, T]
+             kv_cache: tuple[tuple[torch.Tensor]]: tuple of key-value cache tensors
+
+         NOTE: this essentially implements system prompt caching in the worst case when using batch_size==1
+         """
+         if (
+             self.past_token_ids is None
+             or "past_key_values" not in self.past_key_values_kwargs
+         ):
+             return
+
+         # caching doesn't seem to work with multimodal models
+         if self.use_images:
+             self.past_key_values_kwargs = {}
+             return
+
+         past_batch_size, past_seq_len = self.past_token_ids.shape
+         new_batch_size, new_seq_len = new_token_ids.shape
+
+         # If the batch size has changed, reset the cache
+         if past_batch_size != new_batch_size:
+             self.past_key_values_kwargs = {}
+             return
+
+         min_shape = min(past_seq_len, new_seq_len)
+         compare_past = (
+             self.past_token_ids[:, :min_shape] != new_token_ids[:, :min_shape]
+         )
+
+         # All tokens are the same - no need to truncate
+         if not compare_past.any():
+             return
+
+         # Find the first token that is different between the past and new tokens
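+         # argmax over the 0/1 mismatch mask gives the first mismatching position in each row;
+         # taking .min() keeps only the prefix that is still valid for every sequence in the batch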
+         seq_min = torch.argmax(compare_past.double(), dim=1).min()
+
+         # Truncate the key-value cache to the shared prefix
+         # each cache tensor has shape [batch_size, num_heads, seq_len, head_dim]
+         self.past_key_values_kwargs["past_key_values"] = [
+             [kv[:, :, :seq_min, :] for kv in kvs]
+             for kvs in self.past_key_values_kwargs["past_key_values"]
+         ]
+
+     @staticmethod
+     def build_model_kwargs(model_name: str, **kwargs) -> tuple[str, dict]:
+         model_kwargs = {
+             "token": os.getenv("HF_TOKEN"),
+             # "attn_implementation": "flash_attention_2",
+             # "trust_remote_code": True,
+         }
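+         # optional 4-bit / 8-bit quantization through bitsandbytes; otherwise load bfloat16 weights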
+         quantize = kwargs.get("quantize", False)
+         if quantize == "int4":
+             model_kwargs["quantization_config"] = BitsAndBytesConfig(
+                 load_in_4bit=True,
+             )
+         elif quantize == "int8":
+             model_kwargs["quantization_config"] = BitsAndBytesConfig(
+                 load_in_8bit=True,
+             )
+         else:
+             model_kwargs["torch_dtype"] = torch.bfloat16
+
+         downloaded_models = get_downloaded_models()
+         if model_name in downloaded_models:
+             model_kwargs["local_files_only"] = True
+             model_name = downloaded_models[model_name]
+             logger.info(f"Using local model {model_name}")
+         if "/plancraft/outputs" in model_name:
+             model_kwargs["local_files_only"] = True
+             logger.info(f"Using local model {model_name}")
+
+         return model_name, model_kwargs
+
+     def reset(self):
+         # NOTE: past_key_values cache with a rolling window
+         # is not maximally useful as the beginning shifts over time
+         # and therefore the cache is invalidated
+         self.past_key_values_kwargs = {}
+         self.past_token_ids = None
201
+
202
+ def prepare_messages(
203
+ self,
204
+ history: History,
205
+ max_messages_window: int,
206
+ system_prompt: dict = None,
207
+ prompt_images: list = [],
208
+ ) -> tuple[list[dict], list]:
209
+ """
210
+ Prepare the messages using a history
211
+ """
212
+ message_window = history.dialogue_history[-max_messages_window:]
213
+ # remove the first assistant message if it is present
214
+ if len(message_window) > 0 and message_window[0]["role"] == "assistant":
215
+ message_window = message_window[1:]
216
+ # add the system prompt if the first message is not a system message
217
+ if message_window[0]["role"] != "system" and system_prompt is not None:
218
+ message_window = [system_prompt] + message_window
219
+
220
+ image_window = []
221
+ if self.use_images:
222
+ image_list = prompt_images + history.images
223
+ image_count = 0
224
+ # iterate through the messages in reverse order to assign images
225
+ for m in message_window:
226
+ for content in m["content"]:
227
+ if content["type"] == "image":
228
+ image_count += 1
229
+ assert image_count <= len(image_list), "Too many images"
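+             # only the most recent images are needed for the visible message window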
+             image_window = image_list[-image_count:]
+             image_window = [Image.fromarray(img) for img in image_window]
+
+         return message_window, image_window
234
+
235
+ @torch.inference_mode()
236
+ def generate_unconstrained(
237
+ self,
238
+ batch_messages: list[list[dict]],
239
+ start_messages_generation: str = "",
240
+ max_tokens: int = 256,
241
+ temperature=0.6,
242
+ **kwargs,
243
+ ) -> tuple[list[str], int]:
244
+ """
245
+ Generate unconstrained text based on the batch of messages.
246
+ """
247
+ if self.use_images:
248
+ assert "images" in kwargs, "Images required for multimodal model"
249
+
250
+ tokenized_messages = tokenize(
251
+ self.model,
252
+ self.tokenizer,
253
+ batch_messages,
254
+ start_messages_generation=[start_messages_generation] * len(batch_messages),
255
+ max_tokens=max_tokens,
256
+ images=kwargs.get("images") if self.use_images else None,
257
+ )
258
+ prompt_tokens = tokenized_messages["input_ids"].shape[-1]
259
+
260
+ # Sent to the same device as model
261
+ tokenized_messages = {
262
+ k: v.to(self.model.device) for k, v in tokenized_messages.items()
263
+ }
264
+
265
+ # Truncate the key-value cache
266
+ self.truncate_kv_cache(tokenized_messages["input_ids"])
267
+
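+         # the cache must never cover more tokens than the new prompt; if it does, the truncation above failed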
+         if (
+             "past_key_values" in self.past_key_values_kwargs
+             and self.past_key_values_kwargs["past_key_values"][0][0].shape[-2]
+             > tokenized_messages["input_ids"].shape[-1]
+         ):
+             raise ValueError("Past key values are larger than the input_ids")
+
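+         # convert the stored legacy (key, value) tuples into a DynamicCache object before reusing them in generate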
275
+ past_key_values = self.past_key_values_kwargs.get("past_key_values", None)
276
+ if past_key_values is not None:
277
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
278
+
279
+ generated_sequences = self.model.generate(
280
+ **tokenized_messages,
281
+ do_sample=True,
282
+ temperature=temperature,
283
+ max_new_tokens=max_tokens,
284
+ pad_token_id=self.pad_token_id,
285
+ return_dict_in_generate=True,
286
+ use_cache=True,
287
+ past_key_values=past_key_values,
288
+ return_legacy_cache=True,
289
+ )
290
+ # Cache the past key values
291
+ if self.use_hot_cache:
292
+ self.past_key_values_kwargs["past_key_values"] = (
293
+ generated_sequences.past_key_values
294
+ )
295
+ self.past_token_ids = generated_sequences.sequences
296
+
297
+ # Decode the output
298
+ text_responses = self.tokenizer.batch_decode(
299
+ generated_sequences.sequences[:, prompt_tokens:],
300
+ skip_special_tokens=False,
301
+ )
302
+
303
+ text_responses = [
304
+ text_response.replace("<|eot_id|>", "") for text_response in text_responses
305
+ ]
306
+
307
+ _, total_tokens_used = generated_sequences.sequences.shape
308
+ return text_responses, total_tokens_used
+
+
+ class OpenAIGenerator:
+     def __init__(self, use_images=False, model_name="gpt-4o-mini"):
+         self.client = OpenAI()
+         self.use_images = use_images
+         self.model_name = model_name
+
+     def reset(self):
+         pass
+
+     def prepare_messages(
+         self,
+         history: History,
+         max_messages_window: int,
+         system_prompt: dict = None,
+         prompt_images: list = [],
+     ) -> tuple[list[dict], list]:
+         """
+         Prepare the image messages for the model
+         """
+         message_window = history.dialogue_history[-max_messages_window:]
+         # remove the first assistant message if it is present
+         if len(message_window) > 0 and message_window[0]["role"] == "assistant":
+             message_window = message_window[1:]
+         # add the system prompt if the first message is not a system message
+         if message_window[0]["role"] != "system" and system_prompt is not None:
+             message_window = [system_prompt] + message_window
+
+         if self.use_images:
+             image_list = prompt_images + history.images
+             img_idx = -1
+             seen_images = 0
+             # iterate through the messages in reverse order to assign images
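+             # img_idx walks backwards from the most recent image, so the newest observations keep their images when the window is truncated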
+             for i in range(len(message_window) - 1, -1, -1):
+                 new_content_list = []
+                 for content in message_window[i]["content"]:
+                     if content["type"] == "text":
+                         new_content_list.append(content)
+                     elif content["type"] == "image":
+                         base64_image = numpy_to_base64(image_list[img_idx])
+                         img_idx -= 1
+                         seen_images += 1
+                         new_content = {
+                             "type": "image_url",
+                             "image_url": {
+                                 "url": f"data:image/jpeg;base64,{base64_image}"
+                             },
+                         }
+                         new_content_list.append(new_content)
+                 message_window[i]["content"] = new_content_list
+             assert seen_images <= len(image_list), "Too many images"
+
+         return message_window, []
+
+     def generate_unconstrained(
+         self,
+         batch_messages: list[list[dict]],
+         max_tokens=256,
+         **kwargs,
+     ) -> tuple[list[str], int]:
+         contents = []
+         tokens_used = 0
+         for messages in batch_messages:
+             response = self.client.chat.completions.create(
+                 model=self.model_name,
+                 messages=messages,
+                 temperature=0.0,
+                 max_tokens=max_tokens,
+                 top_p=1,
+                 frequency_penalty=0,
+                 presence_penalty=0,
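+                 # stop at the first newline since responses are expected to be single-line actions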
+                 stop=["\n", "\n\n"],
+             )
+             content = response.choices[0].message.content
+             tokens_used += response.usage.total_tokens
+             contents.append(content)
+         return contents, tokens_used
+
+
+ class OAMGenerator:
+     def __init__(
+         self,
+         model_name,
+     ):
+         self.model_name = model_name
+         logger.info("Loading model")
+         time_now = time.time()
+         self.model = PlancraftOAM.from_pretrained(model_name)
+         self.model.cuda()
+         logger.info(f"Model loaded in {time.time() - time_now:.2f} seconds")
+         # compile
+         time_now = time.time()
+         self.model = torch.compile(self.model)
+         logger.info(f"Model compiled in {time.time() - time_now:.2f} seconds")
+
+         self.model.eval()
+         self.model.tokenizer.pad_token = self.model.tokenizer.eos_token
+         self.model.tokenizer.truncation_side = "left"
+
+     def reset(self):
+         pass
+
+     def prepare_messages(
+         self,
+         history: History,
+         max_messages_window: int,
+         system_prompt: dict = None,
+         prompt_images: list = [],
+     ) -> tuple[list[dict], list]:
+         """
+         Prepare the messages using a history
+         """
+         message_window = history.dialogue_history[-max_messages_window:]
+         # remove the first assistant message if it is present
+         if len(message_window) > 0 and message_window[0]["role"] == "assistant":
+             message_window = message_window[1:]
+         # add the system prompt if the first message is not a system message
+         if message_window[0]["role"] != "system" and system_prompt is not None:
+             message_window = [system_prompt] + message_window
+
+         image_window = []
+
+         image_list = prompt_images + history.images
+         image_count = 0
+
+         # iterate through the messages to count how many images are present
+         new_message_window = []
+         for m in message_window:
+             message_content = m["content"]
+             message_role = m["role"]
+
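+             # swap the verbose text inventory for the <|inventory|> placeholder; each placeholder is later paired with an observation image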
+             if "\ninventory:\n" in message_content:
+                 message_content = (
+                     message_content.split("\ninventory:\n")[0]
+                     + "\ninventory:<|inventory|>"
+                 )
+
+             if "<|inventory|>" in message_content:
+                 image_count += 1
+
+             new_message_window.append(
+                 {"role": message_role, "content": message_content}
+             )
+
+         assert image_count <= len(image_list), "Too many images"
+         # take the most recent images from the end of the list
+         image_window = image_list[-image_count:]
+         image_window = [Image.fromarray(img) for img in image_window]
+
+         return new_message_window, image_window
460
+
461
+ @torch.inference_mode()
462
+ def generate_unconstrained(
463
+ self,
464
+ batch_messages: list[list[dict]],
465
+ images: list,
466
+ max_tokens: int = 256,
467
+ temperature=0.6,
468
+ **kwargs,
469
+ ) -> tuple[list[str], int]:
470
+ """
471
+ Generate unconstrained text based on the batch of messages.
472
+ """
473
+ text_responses, total_tokens_used = self.model.generate(
474
+ batch_messages=batch_messages,
475
+ batch_images=images,
476
+ do_sample=True,
477
+ temperature=temperature,
478
+ max_new_tokens=max_tokens,
479
+ )
480
+ return text_responses, total_tokens_used