plancraft-0.1.0-py3-none-any.whl
- environments/__init__.py +0 -0
- environments/actions.py +218 -0
- environments/env_real.py +315 -0
- environments/env_symbolic.py +215 -0
- environments/items.py +10 -0
- environments/planner.py +109 -0
- environments/recipes.py +542 -0
- environments/sampler.py +224 -0
- models/__init__.py +21 -0
- models/act.py +184 -0
- models/base.py +152 -0
- models/bbox_model.py +492 -0
- models/dummy.py +54 -0
- models/few_shot_images/__init__.py +16 -0
- models/generators.py +483 -0
- models/oam.py +284 -0
- models/oracle.py +268 -0
- models/prompts.py +158 -0
- models/react.py +98 -0
- models/utils.py +289 -0
- plancraft-0.1.0.dist-info/LICENSE +21 -0
- plancraft-0.1.0.dist-info/METADATA +53 -0
- plancraft-0.1.0.dist-info/RECORD +26 -0
- plancraft-0.1.0.dist-info/WHEEL +5 -0
- plancraft-0.1.0.dist-info/top_level.txt +3 -0
- train/dataset.py +187 -0
models/generators.py
ADDED
@@ -0,0 +1,483 @@
import logging
import os
import time

import torch
from dotenv import load_dotenv
from openai import OpenAI
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoModelForVision2Seq,
    AutoProcessor,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from transformers.cache_utils import DynamicCache

from plancraft.models.base import History
from plancraft.models.oam import PlancraftOAM
from plancraft.models.utils import (
    get_downloaded_models,
    numpy_to_base64,
    tokenize,
)


logger = logging.getLogger(__name__)

load_dotenv()


class TransformersGenerator:
    def __init__(
        self,
        model_name: str,
        tokenizer_name: str = "same",
        quantize=False,
        use_images=False,
        use_hot_cache=True,
        adapter_name="",
        **kwargs,
    ):
        self.model_name = model_name
        self.use_hot_cache = use_hot_cache

        if tokenizer_name == "same":
            tokenizer_name = model_name

        self.use_images = use_images
        model_name, model_kwargs = self.build_model_kwargs(
            model_name, quantize=quantize
        )
        self.processor = None
        if "idefics" in model_name:
            assert use_images, "Idefics model requires multimodal input"
            self.tokenizer = AutoProcessor.from_pretrained(
                tokenizer_name,
                **model_kwargs,
            )
            self.tokenizer.eos_token_id = self.tokenizer.tokenizer.eos_token_id
            logger.info("Loading model")
            time_now = time.time()
            self.model = AutoModelForVision2Seq.from_pretrained(
                model_name,
                device_map="auto",
                **model_kwargs,
            )
            logger.info(f"Model loaded in {time.time() - time_now:.2f} seconds")
            # set pad_token_id
            if self.tokenizer.tokenizer.pad_token_id:
                self.pad_token_id = self.tokenizer.tokenizer.pad_token_id
            else:
                self.pad_token_id = self.tokenizer.tokenizer.eos_token_id
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_name,
                token=os.getenv("HF_TOKEN"),  # trust_remote_code=True
                padding_side="left",  # ensure that the padding is on the left
            )
            logger.info("Loading model")
            time_now = time.time()
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map="auto",
                **model_kwargs,
            )
            logger.info(f"Model loaded in {time.time() - time_now:.2f} seconds")

            # load OA adapter
            if adapter_name != "":
                logger.info(f"Loading adapter and tokenizer from {adapter_name}")
                self.tokenizer = AutoTokenizer.from_pretrained(
                    adapter_name,
                    padding_side="left",
                )
                self.model.resize_token_embeddings(len(self.tokenizer))
                self.model.load_adapter(adapter_name)

            # set pad_token_id
            if self.tokenizer.pad_token_id:
                self.pad_token_id = self.tokenizer.pad_token_id
            else:
                self.tokenizer.pad_token = self.tokenizer.eos_token
                self.pad_token_id = self.tokenizer.eos_token_id

        # compile
        time_now = time.time()
        self.model = torch.compile(self.model)
        logger.info(f"Model compiled in {time.time() - time_now:.2f} seconds")

        self.model.eval()
        if self.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.model.config.pad_token_id = self.model.config.eos_token_id
            self.tokenizer.truncation_side = "left"

        self.past_key_values_kwargs = {}
        self.past_token_ids = None

    def truncate_kv_cache(self, new_token_ids: torch.Tensor):
        """
        Truncate the key-value cache to the length of the overlap between past_ids and new_ids.

        Uses:
            past_ids: torch.Tensor [B, T]
            new_ids: torch.Tensor [B, T]
            kv_cache: tuple[tuple[torch.Tensor]]: tuple of key-value cache tensors

        NOTE: in the worst case (batch_size == 1) this amounts to caching only the shared system prompt.
        """
        if (
            self.past_token_ids is None
            or "past_key_values" not in self.past_key_values_kwargs
        ):
            return

        # caching doesn't seem to work with multimodal models
        if self.use_images:
            self.past_key_values_kwargs = {}
            return

        past_batch_size, past_seq_len = self.past_token_ids.shape
        new_batch_size, new_seq_len = new_token_ids.shape

        # If the batch size has changed, reset the cache
        if past_batch_size != new_batch_size:
            self.past_key_values_kwargs = {}
            return

        min_shape = min(past_seq_len, new_seq_len)
        compare_past = (
            self.past_token_ids[:, :min_shape] != new_token_ids[:, :min_shape]
        )

        # All tokens are the same - no need to truncate
        if not compare_past.any():
            return

        # Find the first token that differs between the past and new tokens
        seq_min = torch.argmax(compare_past.double(), dim=1).min()

        # Truncate the key-value cache up to the first divergent position
        # assumes each tensor has shape [batch_size, num_heads, seq_len, head_dim]
        self.past_key_values_kwargs["past_key_values"] = [
            [kv[:, :, :seq_min, :] for kv in kvs]
            for kvs in self.past_key_values_kwargs["past_key_values"]
        ]

    @staticmethod
    def build_model_kwargs(model_name: str, **kwargs) -> tuple[str, dict]:
        model_kwargs = {
            "token": os.getenv("HF_TOKEN"),
            # "attn_implementation": "flash_attention_2",
            # "trust_remote_code": True,
        }
        quantize = kwargs.get("quantize", False)
        if quantize == "int4":
            model_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
            )
        elif quantize == "int8":
            model_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_8bit=True,
            )
        else:
            model_kwargs["torch_dtype"] = torch.bfloat16

        downloaded_models = get_downloaded_models()
        if model_name in downloaded_models:
            model_kwargs["local_files_only"] = True
            model_name = downloaded_models[model_name]
            logger.info(f"Using local model {model_name}")
        if "/plancraft/outputs" in model_name:
            model_kwargs["local_files_only"] = True
            logger.info(f"Using local model {model_name}")

        return model_name, model_kwargs

    def reset(self):
        # NOTE: a past_key_values cache over a rolling window
        # is not maximally useful, since the beginning of the window shifts over time
        # and the cache is therefore invalidated
        self.past_key_values_kwargs = {}
        self.past_token_ids = None

    def prepare_messages(
        self,
        history: History,
        max_messages_window: int,
        system_prompt: dict = None,
        prompt_images: list = [],
    ) -> tuple[list[dict], list]:
        """
        Prepare the messages using a history
        """
        message_window = history.dialogue_history[-max_messages_window:]
        # remove the first assistant message if it is present
        if len(message_window) > 0 and message_window[0]["role"] == "assistant":
            message_window = message_window[1:]
        # add the system prompt if the first message is not a system message
        if message_window[0]["role"] != "system" and system_prompt is not None:
            message_window = [system_prompt] + message_window

        image_window = []
        if self.use_images:
            image_list = prompt_images + history.images
            image_count = 0
            # count how many images appear in the message window
            for m in message_window:
                for content in m["content"]:
                    if content["type"] == "image":
                        image_count += 1
            assert image_count <= len(image_list), "Too many images"
            # take the most recent images from the end of the list
            image_window = image_list[-image_count:]
            image_window = [Image.fromarray(img) for img in image_window]

        return message_window, image_window

    @torch.inference_mode()
    def generate_unconstrained(
        self,
        batch_messages: list[list[dict]],
        start_messages_generation: str = "",
        max_tokens: int = 256,
        temperature=0.6,
        **kwargs,
    ) -> tuple[list[str], int]:
        """
        Generate unconstrained text based on the batch of messages.
        """
        if self.use_images:
            assert "images" in kwargs, "Images required for multimodal model"

        tokenized_messages = tokenize(
            self.model,
            self.tokenizer,
            batch_messages,
            start_messages_generation=[start_messages_generation] * len(batch_messages),
            max_tokens=max_tokens,
            images=kwargs.get("images") if self.use_images else None,
        )
        prompt_tokens = tokenized_messages["input_ids"].shape[-1]

        # Send the inputs to the same device as the model
        tokenized_messages = {
            k: v.to(self.model.device) for k, v in tokenized_messages.items()
        }

        # Truncate the key-value cache
        self.truncate_kv_cache(tokenized_messages["input_ids"])

        if (
            "past_key_values" in self.past_key_values_kwargs
            and self.past_key_values_kwargs["past_key_values"][0][0].shape[-2]
            > tokenized_messages["input_ids"].shape[-1]
        ):
            raise ValueError("Past key values are larger than the input_ids")

        past_key_values = self.past_key_values_kwargs.get("past_key_values", None)
        if past_key_values is not None:
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)

        generated_sequences = self.model.generate(
            **tokenized_messages,
            do_sample=True,
            temperature=temperature,
            max_new_tokens=max_tokens,
            pad_token_id=self.pad_token_id,
            return_dict_in_generate=True,
            use_cache=True,
            past_key_values=past_key_values,
            return_legacy_cache=True,
        )
        # Cache the past key values
        if self.use_hot_cache:
            self.past_key_values_kwargs["past_key_values"] = (
                generated_sequences.past_key_values
            )
            self.past_token_ids = generated_sequences.sequences

        # Decode only the newly generated tokens
        text_responses = self.tokenizer.batch_decode(
            generated_sequences.sequences[:, prompt_tokens:],
            skip_special_tokens=False,
        )

        text_responses = [
            text_response.replace("<|eot_id|>", "") for text_response in text_responses
        ]

        _, total_tokens_used = generated_sequences.sequences.shape
        return text_responses, total_tokens_used


class OpenAIGenerator:
    def __init__(self, use_images=False, model_name="gpt-4o-mini"):
        self.client = OpenAI()
        self.use_images = use_images
        self.model_name = model_name

    def reset(self):
        pass

    def prepare_messages(
        self,
        history: History,
        max_messages_window: int,
        system_prompt: dict = None,
        prompt_images: list = [],
    ) -> tuple[list[dict], list]:
        """
        Prepare the image messages for the model
        """
        message_window = history.dialogue_history[-max_messages_window:]
        # remove the first assistant message if it is present
        if len(message_window) > 0 and message_window[0]["role"] == "assistant":
            message_window = message_window[1:]
        # add the system prompt if the first message is not a system message
        if message_window[0]["role"] != "system" and system_prompt is not None:
            message_window = [system_prompt] + message_window

        if self.use_images:
            image_list = prompt_images + history.images
            img_idx = -1
            seen_images = 0
            # iterate through the messages in reverse order to assign images
            for i in range(len(message_window) - 1, -1, -1):
                new_content_list = []
                for content in message_window[i]["content"]:
                    if content["type"] == "text":
                        new_content_list.append(content)
                    elif content["type"] == "image":
                        base64_image = numpy_to_base64(image_list[img_idx])
                        img_idx -= 1
                        seen_images += 1
                        new_content = {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{base64_image}"
                            },
                        }
                        new_content_list.append(new_content)
                message_window[i]["content"] = new_content_list
            assert seen_images <= len(image_list), "Too many images"

        return message_window, []

    def generate_unconstrained(
        self,
        batch_messages: list[list[dict]],
        max_tokens=256,
        **kwargs,
    ) -> tuple[list[str], int]:
        contents = []
        tokens_used = 0
        for messages in batch_messages:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                temperature=0.0,
                max_tokens=max_tokens,
                top_p=1,
                frequency_penalty=0,
                presence_penalty=0,
                stop=["\n", "\n\n"],
            )
            content = response.choices[0].message.content
            tokens_used += response.usage.total_tokens
            contents.append(content)
        return contents, tokens_used


class OAMGenerator:
    def __init__(
        self,
        model_name,
    ):
        self.model_name = model_name
        logger.info("Loading model")
        time_now = time.time()
        self.model = PlancraftOAM.from_pretrained(model_name)
        self.model.cuda()
        logger.info(f"Model loaded in {time.time() - time_now:.2f} seconds")
        # compile
        time_now = time.time()
        self.model = torch.compile(self.model)
        logger.info(f"Model compiled in {time.time() - time_now:.2f} seconds")

        self.model.eval()
        self.model.tokenizer.pad_token = self.model.tokenizer.eos_token
        self.model.tokenizer.truncation_side = "left"

    def reset(self):
        pass

    def prepare_messages(
        self,
        history: History,
        max_messages_window: int,
        system_prompt: dict = None,
        prompt_images: list = [],
    ) -> tuple[list[dict], list]:
        """
        Prepare the messages using a history
        """
        message_window = history.dialogue_history[-max_messages_window:]
        # remove the first assistant message if it is present
        if len(message_window) > 0 and message_window[0]["role"] == "assistant":
            message_window = message_window[1:]
        # add the system prompt if the first message is not a system message
        if message_window[0]["role"] != "system" and system_prompt is not None:
            message_window = [system_prompt] + message_window

        image_window = []

        image_list = prompt_images + history.images
        image_count = 0

        # iterate through the messages to count how many images are present
        new_message_window = []
        for m in message_window:
            message_content = m["content"]
            message_role = m["role"]

            if "\ninventory:\n" in message_content:
                message_content = (
                    message_content.split("\ninventory:\n")[0]
                    + "\ninventory:<|inventory|>"
                )

            if "<|inventory|>" in message_content:
                image_count += 1

            new_message_window.append(
                {"role": message_role, "content": message_content}
            )

        assert image_count <= len(image_list), "Too many images"
        # take the most recent images from the end of the list
        image_window = image_list[-image_count:]
        image_window = [Image.fromarray(img) for img in image_window]

        return new_message_window, image_window

    @torch.inference_mode()
    def generate_unconstrained(
        self,
        batch_messages: list[list[dict]],
        images: list,
        max_tokens: int = 256,
        temperature=0.6,
        **kwargs,
    ) -> tuple[list[str], int]:
        """
        Generate unconstrained text based on the batch of messages.
        """
        text_responses, total_tokens_used = self.model.generate(
            batch_messages=batch_messages,
            batch_images=images,
            do_sample=True,
            temperature=temperature,
            max_new_tokens=max_tokens,
        )
        return text_responses, total_tokens_used
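
For readers inspecting the wheel, here is a minimal standalone sketch of the prefix-overlap check that truncate_kv_cache relies on. The token ids below are hypothetical examples chosen for illustration and are not part of the package:

import torch

# Hypothetical token ids: the new prompt shares a 2-token prefix with the cached one
past_ids = torch.tensor([[1, 2, 3, 4, 5]])
new_ids = torch.tensor([[1, 2, 9, 9]])

min_len = min(past_ids.shape[1], new_ids.shape[1])
diff = past_ids[:, :min_len] != new_ids[:, :min_len]

if diff.any():
    # index of the first position where the sequences diverge, minimised over the batch
    keep = torch.argmax(diff.double(), dim=1).min()
    print(int(keep))  # 2 -> only the first 2 cached key/value positions would be kept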
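
And a rough usage sketch of the generator interface (assumptions: the import path of the module, the attributes exposed by History, and an OPENAI_API_KEY in the environment; this is illustrative only, not taken from the package):

from models.generators import OpenAIGenerator  # import path assumed from the wheel layout

# Duck-typed stand-in for plancraft.models.base.History; only the attributes
# read by prepare_messages are provided here.
class FakeHistory:
    dialogue_history = [{"role": "user", "content": "craft a wooden pickaxe"}]
    images = []

generator = OpenAIGenerator(use_images=False, model_name="gpt-4o-mini")
messages, _ = generator.prepare_messages(
    FakeHistory(),
    max_messages_window=10,
    system_prompt={"role": "system", "content": "You are a crafting assistant."},
)
# Requires OPENAI_API_KEY; returns the model replies and the total tokens used
responses, tokens_used = generator.generate_unconstrained(batch_messages=[messages])
print(responses[0], tokens_used)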