plancraft 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
models/generators.py DELETED
@@ -1,480 +0,0 @@
- import os
- import time
-
- import torch
- from dotenv import load_dotenv
- from loguru import logger
- from openai import OpenAI
- from PIL import Image
- from transformers import (
-     AutoModelForCausalLM,
-     AutoModelForVision2Seq,
-     AutoProcessor,
-     AutoTokenizer,
-     BitsAndBytesConfig,
- )
- from transformers.cache_utils import DynamicCache
-
- from plancraft.models.base import History
- from plancraft.models.oam import PlancraftOAM
- from plancraft.models.utils import (
-     get_downloaded_models,
-     numpy_to_base64,
-     tokenize,
- )
-
- load_dotenv()
-
-
- class TransformersGenerator:
-     def __init__(
-         self,
-         model_name: str,
-         tokenizer_name: str = "same",
-         quantize=False,
-         use_images=False,
-         use_hot_cache=True,
-         adapter_name="",
-         **kwargs,
-     ):
-         self.model_name = model_name
-         self.use_hot_cache = use_hot_cache
-
-         if tokenizer_name == "same":
-             tokenizer_name = model_name
-
-         self.use_images = use_images
-         model_name, model_kwargs = self.build_model_kwargs(
-             model_name, quantize=quantize
-         )
-         self.processor = None
-         if "idefics" in model_name:
-             assert use_images, "Idefics model requires multimodal input"
-             self.tokenizer = AutoProcessor.from_pretrained(
-                 tokenizer_name,
-                 **model_kwargs,
-             )
-             self.tokenizer.eos_token_id = self.tokenizer.tokenizer.eos_token_id
-             logger.info("Loading model")
-             time_now = time.time()
-             self.model = AutoModelForVision2Seq.from_pretrained(
-                 model_name,
-                 device_map="auto",
-                 **model_kwargs,
-             )
-             logger.info(f"Model loaded in {time.time() - time_now:.2f} seconds")
-             # set pad_token_id
-             if self.tokenizer.tokenizer.pad_token_id:
-                 self.pad_token_id = self.tokenizer.tokenizer.pad_token_id
-             else:
-                 self.pad_token_id = self.tokenizer.tokenizer.eos_token_id
-         else:
-             self.tokenizer = AutoTokenizer.from_pretrained(
-                 tokenizer_name,
-                 token=os.getenv("HF_TOKEN"),  # trust_remote_code=True
-                 padding_side="left",  # ensure that the padding is on the left
-             )
-             logger.info("Loading model")
-             time_now = time.time()
-             self.model = AutoModelForCausalLM.from_pretrained(
-                 model_name,
-                 device_map="auto",
-                 **model_kwargs,
-             )
-             logger.info(f"Model loaded in {time.time() - time_now:.2f} seconds")
-
-             # load OA adapter
-             if adapter_name != "":
-                 logger.info(f"Loading adapter and tokenizer from {adapter_name}")
-                 self.tokenizer = AutoTokenizer.from_pretrained(
-                     adapter_name,
-                     padding_side="left",
-                 )
-                 self.model.resize_token_embeddings(len(self.tokenizer))
-                 self.model.load_adapter(adapter_name)
-
-             # set pad_token_id
-             if self.tokenizer.pad_token_id:
-                 self.pad_token_id = self.tokenizer.pad_token_id
-             else:
-                 self.tokenizer.pad_token = self.tokenizer.eos_token
-                 self.pad_token_id = self.tokenizer.eos_token_id
-
-         # compile
-         time_now = time.time()
-         self.model = torch.compile(self.model)
-         logger.info(f"Model compiled in {time.time() - time_now:.2f} seconds")
-
-         self.model.eval()
-         if self.pad_token_id is None:
-             self.tokenizer.pad_token = self.tokenizer.eos_token
-             self.model.config.pad_token_id = self.model.config.eos_token_id
-         self.tokenizer.truncation_side = "left"
-
-         self.past_key_values_kwargs = {}
-         self.past_token_ids = None
-
-     def truncate_kv_cache(self, new_token_ids: torch.Tensor):
-         """
-         Truncate the key-value cache to the size which overlap the past_ids with the new_ids.
-         Uses:
-             past_ids: torch.Tensor [B, T]
-             new_ids: torch.Tensor [B, T]
-             kv_cache: tuple[tuple[torch.Tensor]]: tuple of key-value cache tensors
-
-         NOTE: this essentially implements System Prompt in the worst case when using batch_size==1
-         """
-         if (
-             self.past_token_ids is None
-             or "past_key_values" not in self.past_key_values_kwargs
-         ):
-             return
-
-         # caching doesn't seem to work with multimodal models
-         if self.use_images:
-             self.past_key_values_kwargs = {}
-             return
-
-         past_batch_size, past_seq_len = self.past_token_ids.shape
-         new_batch_size, new_seq_len = new_token_ids.shape
-
-         # If the batch size has changed, reset the cache
-         if past_batch_size != new_batch_size:
-             self.past_key_values_kwargs = {}
-             return
-
-         min_shape = min(past_seq_len, new_seq_len)
-         compare_past = (
-             self.past_token_ids[:, :min_shape] != new_token_ids[:, :min_shape]
-         )
-
-         # All tokens are the same - no need to truncate
-         if not compare_past.any():
-             return
-
-         # Find the first token that is different between the past and new tokens
-         seq_min = torch.argmax(compare_past.double(), dim=1).min()
-
-         # Truncate the key-value cache to the size which overlap the past_ids with the new_ids.
-         # assumes shape is [num_layers, num_heads, seq_len, hidden_size]
-         self.past_key_values_kwargs["past_key_values"] = [
-             [kv[:, :, :seq_min, :] for kv in kvs]
-             for kvs in self.past_key_values_kwargs["past_key_values"]
-         ]
-
-     @staticmethod
-     def build_model_kwargs(model_name: str, **kwargs) -> tuple[str, dict]:
-         model_kwargs = {
-             "token": os.getenv("HF_TOKEN"),
-             # "attn_implementation": "flash_attention_2",
-             # "trust_remote_code": True,
-         }
-         quantize = kwargs.get("quantize", False)
-         if quantize == "int4":
-             model_kwargs["quantization_config"] = BitsAndBytesConfig(
-                 load_in_4bit=True,
-             )
-         elif quantize == "int8":
-             model_kwargs["quantization_config"] = BitsAndBytesConfig(
-                 load_in_8bit=True,
-             )
-         else:
-             model_kwargs["torch_dtype"] = torch.bfloat16
-
-         downloaded_models = get_downloaded_models()
-         if model_name in downloaded_models:
-             model_kwargs["local_files_only"] = True
-             model_name = downloaded_models[model_name]
-             logger.info(f"Using local model {model_name}")
-         if "/plancraft/outputs" in model_name:
-             model_kwargs["local_files_only"] = True
-             logger.info(f"Using local model {model_name}")
-
-         return model_name, model_kwargs
-
-     def reset(self):
-         # NOTE: past_key_values cache with a rolling window
-         # is not maximally useful as the beggining shifts over time
-         # and therefore cache is invalidated
-         self.past_key_values_kwargs = {}
-         self.past_token_ids = None
-
-     def prepare_messages(
-         self,
-         history: History,
-         max_messages_window: int,
-         system_prompt: dict = None,
-         prompt_images: list = [],
-     ) -> tuple[list[dict], list]:
-         """
-         Prepare the messages using a history
-         """
-         message_window = history.dialogue_history[-max_messages_window:]
-         # remove the first assistant message if it is present
-         if len(message_window) > 0 and message_window[0]["role"] == "assistant":
-             message_window = message_window[1:]
-         # add the system prompt if the first message is not a system message
-         if message_window[0]["role"] != "system" and system_prompt is not None:
-             message_window = [system_prompt] + message_window
-
-         image_window = []
-         if self.use_images:
-             image_list = prompt_images + history.images
-             image_count = 0
-             # iterate through the messages in reverse order to assign images
-             for m in message_window:
-                 for content in m["content"]:
-                     if content["type"] == "image":
-                         image_count += 1
-             assert image_count <= len(image_list), "Too many images"
-             image_window = image_list[-image_count:]
-             image_window = [Image.fromarray(img) for img in image_window]
-
-         return message_window, image_window
-
-     @torch.inference_mode()
-     def generate_unconstrained(
-         self,
-         batch_messages: list[list[dict]],
-         start_messages_generation: str = "",
-         max_tokens: int = 256,
-         temperature=0.6,
-         **kwargs,
-     ) -> tuple[list[str], int]:
-         """
-         Generate unconstrained text based on the batch of messages.
-         """
-         if self.use_images:
-             assert "images" in kwargs, "Images required for multimodal model"
-
-         tokenized_messages = tokenize(
-             self.model,
-             self.tokenizer,
-             batch_messages,
-             start_messages_generation=[start_messages_generation] * len(batch_messages),
-             max_tokens=max_tokens,
-             images=kwargs.get("images") if self.use_images else None,
-         )
-         prompt_tokens = tokenized_messages["input_ids"].shape[-1]
-
-         # Sent to the same device as model
-         tokenized_messages = {
-             k: v.to(self.model.device) for k, v in tokenized_messages.items()
-         }
-
-         # Truncate the key-value cache
-         self.truncate_kv_cache(tokenized_messages["input_ids"])
-
-         if (
-             "past_key_values" in self.past_key_values_kwargs
-             and self.past_key_values_kwargs["past_key_values"][0][0].shape[-2]
-             > tokenized_messages["input_ids"].shape[-1]
-         ):
-             raise ValueError("Past key values are larger than the input_ids")
-
-         past_key_values = self.past_key_values_kwargs.get("past_key_values", None)
-         if past_key_values is not None:
-             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-
-         generated_sequences = self.model.generate(
-             **tokenized_messages,
-             do_sample=True,
-             temperature=temperature,
-             max_new_tokens=max_tokens,
-             pad_token_id=self.pad_token_id,
-             return_dict_in_generate=True,
-             use_cache=True,
-             past_key_values=past_key_values,
-             return_legacy_cache=True,
-         )
-         # Cache the past key values
-         if self.use_hot_cache:
-             self.past_key_values_kwargs["past_key_values"] = (
-                 generated_sequences.past_key_values
-             )
-             self.past_token_ids = generated_sequences.sequences
-
-         # Decode the output
-         text_responses = self.tokenizer.batch_decode(
-             generated_sequences.sequences[:, prompt_tokens:],
-             skip_special_tokens=False,
-         )
-
-         text_responses = [
-             text_response.replace("<|eot_id|>", "") for text_response in text_responses
-         ]
-
-         _, total_tokens_used = generated_sequences.sequences.shape
-         return text_responses, total_tokens_used
-
-
- class OpenAIGenerator:
-     def __init__(self, use_images=False, model_name="gpt-4o-mini"):
-         self.client = OpenAI()
-         self.use_images = use_images
-         self.model_name = model_name
-
-     def reset(self):
-         pass
-
-     def prepare_messages(
-         self,
-         history: History,
-         max_messages_window: int,
-         system_prompt: dict = None,
-         prompt_images: list = [],
-     ) -> tuple[list[dict], list]:
-         """
-         Prepare the image messages for the model
-         """
-         message_window = history.dialogue_history[-max_messages_window:]
-         # remove the first assistant message if it is present
-         if len(message_window) > 0 and message_window[0]["role"] == "assistant":
-             message_window = message_window[1:]
-         # add the system prompt if the first message is not a system message
-         if message_window[0]["role"] != "system" and system_prompt is not None:
-             message_window = [system_prompt] + message_window
-
-         if self.use_images:
-             image_list = prompt_images + history.images
-             img_idx = -1
-             seen_images = 0
-             # iterate through the messages in reverse order to assign images
-             for i in range(len(message_window) - 1, -1, -1):
-                 new_content_list = []
-                 for content in message_window[i]["content"]:
-                     if content["type"] == "text":
-                         new_content_list.append(content)
-                     elif content["type"] == "image":
-                         base64_image = numpy_to_base64(image_list[img_idx])
-                         img_idx -= 1
-                         seen_images + 1
-                         new_content = {
-                             "type": "image_url",
-                             "image_url": {
-                                 "url": f"data:image/jpeg;base64,{base64_image}"
-                             },
-                         }
-                         new_content_list.append(new_content)
-                 message_window[i]["content"] = new_content_list
-             assert seen_images <= len(image_list), "Too many images"
-
-         return message_window, []
-
-     def generate_unconstrained(
-         self,
-         batch_messages: list[list[dict]],
-         max_tokens=256,
-         **kwargs,
-     ) -> tuple[list[str], int]:
-         contents = []
-         tokens_used = 0
-         for messages in batch_messages:
-             response = self.client.chat.completions.create(
-                 model=self.model_name,
-                 messages=messages,
-                 temperature=0.0,
-                 max_tokens=max_tokens,
-                 top_p=1,
-                 frequency_penalty=0,
-                 presence_penalty=0,
-                 stop=["\n", "\n\n"],
-             )
-             content = response.choices[0].message.content
-             tokens_used += response.usage.total_tokens
-             contents.append(content)
-         return contents, tokens_used
-
-
- class OAMGenerator:
-     def __init__(
-         self,
-         model_name,
-     ):
-         self.model_name = model_name
-         logger.info("Loading model")
-         time_now = time.time()
-         self.model = PlancraftOAM.from_pretrained(model_name)
-         self.model.cuda()
-         logger.info(f"Model loaded in {time.time() - time_now:.2f} seconds")
-         # compile
-         time_now = time.time()
-         self.model = torch.compile(self.model)
-         logger.info(f"Model compiled in {time.time() - time_now:.2f} seconds")
-
-         self.model.eval()
-         self.model.tokenizer.pad_token = self.model.tokenizer.eos_token
-         self.model.tokenizer.truncation_side = "left"
-
-     def reset(self):
-         pass
-
-     def prepare_messages(
-         self,
-         history: History,
-         max_messages_window: int,
-         system_prompt: dict = None,
-         prompt_images: list = [],
-     ) -> tuple[list[dict], list]:
-         """
-         Prepare the messages using a history
-         """
-         message_window = history.dialogue_history[-max_messages_window:]
-         # remove the first assistant message if it is present
-         if len(message_window) > 0 and message_window[0]["role"] == "assistant":
-             message_window = message_window[1:]
-         # add the system prompt if the first message is not a system message
-         if message_window[0]["role"] != "system" and system_prompt is not None:
-             message_window = [system_prompt] + message_window
-
-         image_window = []
-
-         image_list = prompt_images + history.images
-         image_count = 0
-
-         # iterate through the messages to count how many images are present
-         new_message_window = []
-         for m in message_window:
-             message_content = m["content"]
-             message_role = m["role"]
-
-             if "\ninventory:\n" in message_content:
-                 message_content = (
-                     message_content.split("\ninventory:\n")[0]
-                     + "\ninventory:<|inventory|>"
-                 )
-
-             if "<|inventory|>" in message_content:
-                 image_count += 1
-
-             new_message_window.append(
-                 {"role": message_role, "content": message_content}
-             )
-
-         assert image_count <= len(image_list), "Too many images"
-         # get messages from end of queue
-         image_window = image_list[-image_count:]
-         image_window = [Image.fromarray(img) for img in image_window]
-
-         return new_message_window, image_window
-
-     @torch.inference_mode()
-     def generate_unconstrained(
-         self,
-         batch_messages: list[list[dict]],
-         images: list,
-         max_tokens: int = 256,
-         temperature=0.6,
-         **kwargs,
-     ) -> tuple[list[str], int]:
-         """
-         Generate unconstrained text based on the batch of messages.
-         """
-         text_responses, total_tokens_used = self.model.generate(
-             batch_messages=batch_messages,
-             batch_images=images,
-             do_sample=True,
-             temperature=temperature,
-             max_new_tokens=max_tokens,
-         )
-         return text_responses, total_tokens_used
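For reference, the three removed classes shared the same small interface (reset, prepare_messages, generate_unconstrained). A minimal usage sketch of the OpenAI-backed generator as it existed before this release, assuming plancraft 0.1.1 is installed and OPENAI_API_KEY is set; the message contents are hypothetical and prepare_messages (which requires a plancraft History object) is skipped:

from plancraft.models.generators import OpenAIGenerator  # module removed in 0.1.2

# Hypothetical example: plain-text chat messages in the OpenAI format that
# generate_unconstrained forwards to client.chat.completions.create.
generator = OpenAIGenerator(use_images=False, model_name="gpt-4o-mini")
batch_messages = [
    [
        {"role": "system", "content": "You are a Plancraft crafting assistant."},
        {"role": "user", "content": "inventory: 2 oak_planks\nWhat should I do next?"},
    ]
]
# Returns one completion per message list plus the total tokens consumed.
contents, tokens_used = generator.generate_unconstrained(batch_messages, max_tokens=64)
print(contents[0], tokens_used)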