plancraft 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl

models/generators.py DELETED
@@ -1,480 +0,0 @@
- import os
- import time
-
- import torch
- from dotenv import load_dotenv
- from loguru import logger
- from openai import OpenAI
- from PIL import Image
- from transformers import (
-     AutoModelForCausalLM,
-     AutoModelForVision2Seq,
-     AutoProcessor,
-     AutoTokenizer,
-     BitsAndBytesConfig,
- )
- from transformers.cache_utils import DynamicCache
-
- from plancraft.models.base import History
- from plancraft.models.oam import PlancraftOAM
- from plancraft.models.utils import (
-     get_downloaded_models,
-     numpy_to_base64,
-     tokenize,
- )
-
- load_dotenv()
-
-
- class TransformersGenerator:
-     def __init__(
-         self,
-         model_name: str,
-         tokenizer_name: str = "same",
-         quantize=False,
-         use_images=False,
-         use_hot_cache=True,
-         adapter_name="",
-         **kwargs,
-     ):
-         self.model_name = model_name
-         self.use_hot_cache = use_hot_cache
-
-         if tokenizer_name == "same":
-             tokenizer_name = model_name
-
-         self.use_images = use_images
-         model_name, model_kwargs = self.build_model_kwargs(
-             model_name, quantize=quantize
-         )
-         self.processor = None
-         if "idefics" in model_name:
-             assert use_images, "Idefics model requires multimodal input"
-             self.tokenizer = AutoProcessor.from_pretrained(
-                 tokenizer_name,
-                 **model_kwargs,
-             )
-             self.tokenizer.eos_token_id = self.tokenizer.tokenizer.eos_token_id
-             logger.info("Loading model")
-             time_now = time.time()
-             self.model = AutoModelForVision2Seq.from_pretrained(
-                 model_name,
-                 device_map="auto",
-                 **model_kwargs,
-             )
-             logger.info(f"Model loaded in {time.time() - time_now:.2f} seconds")
-             # set pad_token_id
-             if self.tokenizer.tokenizer.pad_token_id:
-                 self.pad_token_id = self.tokenizer.tokenizer.pad_token_id
-             else:
-                 self.pad_token_id = self.tokenizer.tokenizer.eos_token_id
-         else:
-             self.tokenizer = AutoTokenizer.from_pretrained(
-                 tokenizer_name,
-                 token=os.getenv("HF_TOKEN"), # trust_remote_code=True
-                 padding_side="left", # ensure that the padding is on the left
-             )
-             logger.info("Loading model")
-             time_now = time.time()
-             self.model = AutoModelForCausalLM.from_pretrained(
-                 model_name,
-                 device_map="auto",
-                 **model_kwargs,
-             )
-             logger.info(f"Model loaded in {time.time() - time_now:.2f} seconds")
-
-             # load OA adapter
-             if adapter_name != "":
-                 logger.info(f"Loading adapter and tokenizer from {adapter_name}")
-                 self.tokenizer = AutoTokenizer.from_pretrained(
-                     adapter_name,
-                     padding_side="left",
-                 )
-                 self.model.resize_token_embeddings(len(self.tokenizer))
-                 self.model.load_adapter(adapter_name)
-
-             # set pad_token_id
-             if self.tokenizer.pad_token_id:
-                 self.pad_token_id = self.tokenizer.pad_token_id
-             else:
-                 self.tokenizer.pad_token = self.tokenizer.eos_token
-                 self.pad_token_id = self.tokenizer.eos_token_id
-
-         # compile
-         time_now = time.time()
-         self.model = torch.compile(self.model)
-         logger.info(f"Model compiled in {time.time() - time_now:.2f} seconds")
-
-         self.model.eval()
-         if self.pad_token_id is None:
-             self.tokenizer.pad_token = self.tokenizer.eos_token
-             self.model.config.pad_token_id = self.model.config.eos_token_id
-         self.tokenizer.truncation_side = "left"
-
-         self.past_key_values_kwargs = {}
-         self.past_token_ids = None
-
-     def truncate_kv_cache(self, new_token_ids: torch.Tensor):
-         """
-         Truncate the key-value cache to the length at which the past_ids overlap with the new_ids.
-         Uses:
-             past_ids: torch.Tensor [B, T]
-             new_ids: torch.Tensor [B, T]
-             kv_cache: tuple[tuple[torch.Tensor]]: tuple of key-value cache tensors
-
-         NOTE: this essentially implements System Prompt in the worst case when using batch_size==1
-         """
-         if (
-             self.past_token_ids is None
-             or "past_key_values" not in self.past_key_values_kwargs
-         ):
-             return
-
-         # caching doesn't seem to work with multimodal models
-         if self.use_images:
-             self.past_key_values_kwargs = {}
-             return
-
-         past_batch_size, past_seq_len = self.past_token_ids.shape
-         new_batch_size, new_seq_len = new_token_ids.shape
-
-         # If the batch size has changed, reset the cache
-         if past_batch_size != new_batch_size:
-             self.past_key_values_kwargs = {}
-             return
-
-         min_shape = min(past_seq_len, new_seq_len)
-         compare_past = (
-             self.past_token_ids[:, :min_shape] != new_token_ids[:, :min_shape]
-         )
-
-         # All tokens are the same - no need to truncate
-         if not compare_past.any():
-             return
-
-         # Find the first token that is different between the past and new tokens
-         seq_min = torch.argmax(compare_past.double(), dim=1).min()
-
-         # Truncate the key-value cache to the length at which the past_ids overlap with the new_ids.
-         # assumes shape is [num_layers, num_heads, seq_len, hidden_size]
-         self.past_key_values_kwargs["past_key_values"] = [
-             [kv[:, :, :seq_min, :] for kv in kvs]
-             for kvs in self.past_key_values_kwargs["past_key_values"]
-         ]
-
-     @staticmethod
-     def build_model_kwargs(model_name: str, **kwargs) -> tuple[str, dict]:
-         model_kwargs = {
-             "token": os.getenv("HF_TOKEN"),
-             # "attn_implementation": "flash_attention_2",
-             # "trust_remote_code": True,
-         }
-         quantize = kwargs.get("quantize", False)
-         if quantize == "int4":
-             model_kwargs["quantization_config"] = BitsAndBytesConfig(
-                 load_in_4bit=True,
-             )
-         elif quantize == "int8":
-             model_kwargs["quantization_config"] = BitsAndBytesConfig(
-                 load_in_8bit=True,
-             )
-         else:
-             model_kwargs["torch_dtype"] = torch.bfloat16
-
-         downloaded_models = get_downloaded_models()
-         if model_name in downloaded_models:
-             model_kwargs["local_files_only"] = True
-             model_name = downloaded_models[model_name]
-             logger.info(f"Using local model {model_name}")
-         if "/plancraft/outputs" in model_name:
-             model_kwargs["local_files_only"] = True
-             logger.info(f"Using local model {model_name}")
-
-         return model_name, model_kwargs
-
-     def reset(self):
-         # NOTE: past_key_values cache with a rolling window
-         # is not maximally useful as the beginning shifts over time
-         # and therefore cache is invalidated
-         self.past_key_values_kwargs = {}
-         self.past_token_ids = None
-
-     def prepare_messages(
-         self,
-         history: History,
-         max_messages_window: int,
-         system_prompt: dict = None,
-         prompt_images: list = [],
-     ) -> tuple[list[dict], list]:
-         """
-         Prepare the messages using a history
-         """
-         message_window = history.dialogue_history[-max_messages_window:]
-         # remove the first assistant message if it is present
-         if len(message_window) > 0 and message_window[0]["role"] == "assistant":
-             message_window = message_window[1:]
-         # add the system prompt if the first message is not a system message
-         if message_window[0]["role"] != "system" and system_prompt is not None:
-             message_window = [system_prompt] + message_window
-
-         image_window = []
-         if self.use_images:
-             image_list = prompt_images + history.images
-             image_count = 0
-             # iterate through the messages to count how many images are needed
-             for m in message_window:
-                 for content in m["content"]:
-                     if content["type"] == "image":
-                         image_count += 1
-             assert image_count <= len(image_list), "Too many images"
-             image_window = image_list[-image_count:]
-             image_window = [Image.fromarray(img) for img in image_window]
-
-         return message_window, image_window
-
-     @torch.inference_mode()
-     def generate_unconstrained(
-         self,
-         batch_messages: list[list[dict]],
-         start_messages_generation: str = "",
-         max_tokens: int = 256,
-         temperature=0.6,
-         **kwargs,
-     ) -> tuple[list[str], int]:
-         """
-         Generate unconstrained text based on the batch of messages.
-         """
-         if self.use_images:
-             assert "images" in kwargs, "Images required for multimodal model"
-
-         tokenized_messages = tokenize(
-             self.model,
-             self.tokenizer,
-             batch_messages,
-             start_messages_generation=[start_messages_generation] * len(batch_messages),
-             max_tokens=max_tokens,
-             images=kwargs.get("images") if self.use_images else None,
-         )
-         prompt_tokens = tokenized_messages["input_ids"].shape[-1]
-
-         # Send to the same device as the model
-         tokenized_messages = {
-             k: v.to(self.model.device) for k, v in tokenized_messages.items()
-         }
-
-         # Truncate the key-value cache
-         self.truncate_kv_cache(tokenized_messages["input_ids"])
-
-         if (
-             "past_key_values" in self.past_key_values_kwargs
-             and self.past_key_values_kwargs["past_key_values"][0][0].shape[-2]
-             > tokenized_messages["input_ids"].shape[-1]
-         ):
-             raise ValueError("Past key values are larger than the input_ids")
-
-         past_key_values = self.past_key_values_kwargs.get("past_key_values", None)
-         if past_key_values is not None:
-             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-
-         generated_sequences = self.model.generate(
-             **tokenized_messages,
-             do_sample=True,
-             temperature=temperature,
-             max_new_tokens=max_tokens,
-             pad_token_id=self.pad_token_id,
-             return_dict_in_generate=True,
-             use_cache=True,
-             past_key_values=past_key_values,
-             return_legacy_cache=True,
-         )
-         # Cache the past key values
-         if self.use_hot_cache:
-             self.past_key_values_kwargs["past_key_values"] = (
-                 generated_sequences.past_key_values
-             )
-             self.past_token_ids = generated_sequences.sequences
-
-         # Decode the output
-         text_responses = self.tokenizer.batch_decode(
-             generated_sequences.sequences[:, prompt_tokens:],
-             skip_special_tokens=False,
-         )
-
-         text_responses = [
-             text_response.replace("<|eot_id|>", "") for text_response in text_responses
-         ]
-
-         _, total_tokens_used = generated_sequences.sequences.shape
-         return text_responses, total_tokens_used
-
-
- class OpenAIGenerator:
-     def __init__(self, use_images=False, model_name="gpt-4o-mini"):
-         self.client = OpenAI()
-         self.use_images = use_images
-         self.model_name = model_name
-
-     def reset(self):
-         pass
-
-     def prepare_messages(
-         self,
-         history: History,
-         max_messages_window: int,
-         system_prompt: dict = None,
-         prompt_images: list = [],
-     ) -> tuple[list[dict], list]:
-         """
-         Prepare the image messages for the model
-         """
-         message_window = history.dialogue_history[-max_messages_window:]
-         # remove the first assistant message if it is present
-         if len(message_window) > 0 and message_window[0]["role"] == "assistant":
-             message_window = message_window[1:]
-         # add the system prompt if the first message is not a system message
-         if message_window[0]["role"] != "system" and system_prompt is not None:
-             message_window = [system_prompt] + message_window
-
-         if self.use_images:
-             image_list = prompt_images + history.images
-             img_idx = -1
-             seen_images = 0
-             # iterate through the messages in reverse order to assign images
-             for i in range(len(message_window) - 1, -1, -1):
-                 new_content_list = []
-                 for content in message_window[i]["content"]:
-                     if content["type"] == "text":
-                         new_content_list.append(content)
-                     elif content["type"] == "image":
-                         base64_image = numpy_to_base64(image_list[img_idx])
-                         img_idx -= 1
-                         seen_images += 1
-                         new_content = {
-                             "type": "image_url",
-                             "image_url": {
-                                 "url": f"data:image/jpeg;base64,{base64_image}"
-                             },
-                         }
-                         new_content_list.append(new_content)
-                 message_window[i]["content"] = new_content_list
-             assert seen_images <= len(image_list), "Too many images"
-
-         return message_window, []
-
-     def generate_unconstrained(
-         self,
-         batch_messages: list[list[dict]],
-         max_tokens=256,
-         **kwargs,
-     ) -> tuple[list[str], int]:
-         contents = []
-         tokens_used = 0
-         for messages in batch_messages:
-             response = self.client.chat.completions.create(
-                 model=self.model_name,
-                 messages=messages,
-                 temperature=0.0,
-                 max_tokens=max_tokens,
-                 top_p=1,
-                 frequency_penalty=0,
-                 presence_penalty=0,
-                 stop=["\n", "\n\n"],
-             )
-             content = response.choices[0].message.content
-             tokens_used += response.usage.total_tokens
-             contents.append(content)
-         return contents, tokens_used
-
-
- class OAMGenerator:
-     def __init__(
-         self,
-         model_name,
-     ):
-         self.model_name = model_name
-         logger.info("Loading model")
-         time_now = time.time()
-         self.model = PlancraftOAM.from_pretrained(model_name)
-         self.model.cuda()
-         logger.info(f"Model loaded in {time.time() - time_now:.2f} seconds")
-         # compile
-         time_now = time.time()
-         self.model = torch.compile(self.model)
-         logger.info(f"Model compiled in {time.time() - time_now:.2f} seconds")
-
-         self.model.eval()
-         self.model.tokenizer.pad_token = self.model.tokenizer.eos_token
-         self.model.tokenizer.truncation_side = "left"
-
-     def reset(self):
-         pass
-
-     def prepare_messages(
-         self,
-         history: History,
-         max_messages_window: int,
-         system_prompt: dict = None,
-         prompt_images: list = [],
-     ) -> tuple[list[dict], list]:
-         """
-         Prepare the messages using a history
-         """
-         message_window = history.dialogue_history[-max_messages_window:]
-         # remove the first assistant message if it is present
-         if len(message_window) > 0 and message_window[0]["role"] == "assistant":
-             message_window = message_window[1:]
-         # add the system prompt if the first message is not a system message
-         if message_window[0]["role"] != "system" and system_prompt is not None:
-             message_window = [system_prompt] + message_window
-
-         image_window = []
-
-         image_list = prompt_images + history.images
-         image_count = 0
-
-         # iterate through the messages to count how many images are present
-         new_message_window = []
-         for m in message_window:
-             message_content = m["content"]
-             message_role = m["role"]
-
-             if "\ninventory:\n" in message_content:
-                 message_content = (
-                     message_content.split("\ninventory:\n")[0]
-                     + "\ninventory:<|inventory|>"
-                 )
-
-             if "<|inventory|>" in message_content:
-                 image_count += 1
-
-             new_message_window.append(
-                 {"role": message_role, "content": message_content}
-             )
-
-         assert image_count <= len(image_list), "Too many images"
-         # get the images from the end of the list
-         image_window = image_list[-image_count:]
-         image_window = [Image.fromarray(img) for img in image_window]
-
-         return new_message_window, image_window
-
-     @torch.inference_mode()
-     def generate_unconstrained(
-         self,
-         batch_messages: list[list[dict]],
-         images: list,
-         max_tokens: int = 256,
-         temperature=0.6,
-         **kwargs,
-     ) -> tuple[list[str], int]:
-         """
-         Generate unconstrained text based on the batch of messages.
-         """
-         text_responses, total_tokens_used = self.model.generate(
-             batch_messages=batch_messages,
-             batch_images=images,
-             do_sample=True,
-             temperature=temperature,
-             max_new_tokens=max_tokens,
-         )
-         return text_responses, total_tokens_used