plancraft 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- environments/__init__.py +0 -0
- environments/actions.py +218 -0
- environments/env_real.py +315 -0
- environments/env_symbolic.py +215 -0
- environments/items.py +10 -0
- environments/planner.py +109 -0
- environments/recipes.py +542 -0
- environments/sampler.py +224 -0
- models/__init__.py +21 -0
- models/act.py +184 -0
- models/base.py +152 -0
- models/bbox_model.py +492 -0
- models/dummy.py +54 -0
- models/few_shot_images/__init__.py +16 -0
- models/generators.py +483 -0
- models/oam.py +284 -0
- models/oracle.py +268 -0
- models/prompts.py +158 -0
- models/react.py +98 -0
- models/utils.py +289 -0
- plancraft-0.1.0.dist-info/LICENSE +21 -0
- plancraft-0.1.0.dist-info/METADATA +53 -0
- plancraft-0.1.0.dist-info/RECORD +26 -0
- plancraft-0.1.0.dist-info/WHEEL +5 -0
- plancraft-0.1.0.dist-info/top_level.txt +3 -0
- train/dataset.py +187 -0
models/generators.py
ADDED
@@ -0,0 +1,483 @@
import logging
import os
import time

import torch
from dotenv import load_dotenv
from openai import OpenAI
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoModelForVision2Seq,
    AutoProcessor,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from transformers.cache_utils import DynamicCache

from plancraft.models.base import History
from plancraft.models.oam import PlancraftOAM
from plancraft.models.utils import (
    get_downloaded_models,
    numpy_to_base64,
    tokenize,
)


logger = logging.getLogger(__name__)

load_dotenv()


class TransformersGenerator:
    def __init__(
        self,
        model_name: str,
        tokenizer_name: str = "same",
        quantize=False,
        use_images=False,
        use_hot_cache=True,
        adapter_name="",
        **kwargs,
    ):
        self.model_name = model_name
        self.use_hot_cache = use_hot_cache

        if tokenizer_name == "same":
            tokenizer_name = model_name

        self.use_images = use_images
        model_name, model_kwargs = self.build_model_kwargs(
            model_name, quantize=quantize
        )
        self.processor = None
        if "idefics" in model_name:
            assert use_images, "Idefics model requires multimodal input"
            self.tokenizer = AutoProcessor.from_pretrained(
                tokenizer_name,
                **model_kwargs,
            )
            self.tokenizer.eos_token_id = self.tokenizer.tokenizer.eos_token_id
            logger.info("Loading model")
            time_now = time.time()
            self.model = AutoModelForVision2Seq.from_pretrained(
                model_name,
                device_map="auto",
                **model_kwargs,
            )
            logger.info(f"Model loaded in {time.time() - time_now:.2f} seconds")
            # set pad_token_id
            if self.tokenizer.tokenizer.pad_token_id:
                self.pad_token_id = self.tokenizer.tokenizer.pad_token_id
            else:
                self.pad_token_id = self.tokenizer.tokenizer.eos_token_id
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_name,
                token=os.getenv("HF_TOKEN"),  # trust_remote_code=True
                padding_side="left",  # ensure that the padding is on the left
            )
            logger.info("Loading model")
            time_now = time.time()
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map="auto",
                **model_kwargs,
            )
            logger.info(f"Model loaded in {time.time() - time_now:.2f} seconds")

            # load OA adapter
            if adapter_name != "":
                logger.info(f"Loading adapter and tokenizer from {adapter_name}")
                self.tokenizer = AutoTokenizer.from_pretrained(
                    adapter_name,
                    padding_side="left",
                )
                self.model.resize_token_embeddings(len(self.tokenizer))
                self.model.load_adapter(adapter_name)

            # set pad_token_id
            if self.tokenizer.pad_token_id:
                self.pad_token_id = self.tokenizer.pad_token_id
            else:
                self.tokenizer.pad_token = self.tokenizer.eos_token
                self.pad_token_id = self.tokenizer.eos_token_id

        # compile
        time_now = time.time()
        self.model = torch.compile(self.model)
        logger.info(f"Model compiled in {time.time() - time_now:.2f} seconds")

        self.model.eval()
        if self.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.model.config.pad_token_id = self.model.config.eos_token_id
            self.tokenizer.truncation_side = "left"

        self.past_key_values_kwargs = {}
        self.past_token_ids = None

    def truncate_kv_cache(self, new_token_ids: torch.Tensor):
        """
        Truncate the key-value cache to the length at which past_ids and new_ids overlap.
        Uses:
            past_ids: torch.Tensor [B, T]
            new_ids: torch.Tensor [B, T]
            kv_cache: tuple[tuple[torch.Tensor]]: tuple of key-value cache tensors

        NOTE: this essentially implements system-prompt caching in the worst case when using batch_size==1
        """
        if (
            self.past_token_ids is None
            or "past_key_values" not in self.past_key_values_kwargs
        ):
            return

        # caching doesn't seem to work with multimodal models
        if self.use_images:
            self.past_key_values_kwargs = {}
            return

        past_batch_size, past_seq_len = self.past_token_ids.shape
        new_batch_size, new_seq_len = new_token_ids.shape

        # If the batch size has changed, reset the cache
        if past_batch_size != new_batch_size:
            self.past_key_values_kwargs = {}
            return

        min_shape = min(past_seq_len, new_seq_len)
        compare_past = (
            self.past_token_ids[:, :min_shape] != new_token_ids[:, :min_shape]
        )

        # All tokens are the same - no need to truncate
        if not compare_past.any():
            return

        # Find the first token that is different between the past and new tokens
        seq_min = torch.argmax(compare_past.double(), dim=1).min()

        # Truncate the key-value cache to the overlapping prefix of past_ids and new_ids.
        # assumes shape is [num_layers, num_heads, seq_len, hidden_size]
        self.past_key_values_kwargs["past_key_values"] = [
            [kv[:, :, :seq_min, :] for kv in kvs]
            for kvs in self.past_key_values_kwargs["past_key_values"]
        ]

    @staticmethod
    def build_model_kwargs(model_name: str, **kwargs) -> tuple[str, dict]:
        model_kwargs = {
            "token": os.getenv("HF_TOKEN"),
            # "attn_implementation": "flash_attention_2",
            # "trust_remote_code": True,
        }
        quantize = kwargs.get("quantize", False)
        if quantize == "int4":
            model_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
            )
        elif quantize == "int8":
            model_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_8bit=True,
            )
        else:
            model_kwargs["torch_dtype"] = torch.bfloat16

        downloaded_models = get_downloaded_models()
        if model_name in downloaded_models:
            model_kwargs["local_files_only"] = True
            model_name = downloaded_models[model_name]
            logger.info(f"Using local model {model_name}")
        if "/plancraft/outputs" in model_name:
            model_kwargs["local_files_only"] = True
            logger.info(f"Using local model {model_name}")

        return model_name, model_kwargs

    def reset(self):
        # NOTE: past_key_values cache with a rolling window
        # is not maximally useful as the beginning shifts over time
        # and therefore the cache is invalidated
        self.past_key_values_kwargs = {}
        self.past_token_ids = None

    def prepare_messages(
        self,
        history: History,
        max_messages_window: int,
        system_prompt: dict = None,
        prompt_images: list = [],
    ) -> tuple[list[dict], list]:
        """
        Prepare the messages using a history
        """
        message_window = history.dialogue_history[-max_messages_window:]
        # remove the first assistant message if it is present
        if len(message_window) > 0 and message_window[0]["role"] == "assistant":
            message_window = message_window[1:]
        # add the system prompt if the first message is not a system message
        if message_window[0]["role"] != "system" and system_prompt is not None:
            message_window = [system_prompt] + message_window

        image_window = []
        if self.use_images:
            image_list = prompt_images + history.images
            image_count = 0
            # iterate through the messages to count how many images are needed
            for m in message_window:
                for content in m["content"]:
                    if content["type"] == "image":
                        image_count += 1
            assert image_count <= len(image_list), "Too many images"
            image_window = image_list[-image_count:]
            image_window = [Image.fromarray(img) for img in image_window]

        return message_window, image_window

    @torch.inference_mode()
    def generate_unconstrained(
        self,
        batch_messages: list[list[dict]],
        start_messages_generation: str = "",
        max_tokens: int = 256,
        temperature=0.6,
        **kwargs,
    ) -> tuple[list[str], int]:
        """
        Generate unconstrained text based on the batch of messages.
        """
        if self.use_images:
            assert "images" in kwargs, "Images required for multimodal model"

        tokenized_messages = tokenize(
            self.model,
            self.tokenizer,
            batch_messages,
            start_messages_generation=[start_messages_generation] * len(batch_messages),
            max_tokens=max_tokens,
            images=kwargs.get("images") if self.use_images else None,
        )
        prompt_tokens = tokenized_messages["input_ids"].shape[-1]

        # Send to the same device as the model
        tokenized_messages = {
            k: v.to(self.model.device) for k, v in tokenized_messages.items()
        }

        # Truncate the key-value cache
        self.truncate_kv_cache(tokenized_messages["input_ids"])

        if (
            "past_key_values" in self.past_key_values_kwargs
            and self.past_key_values_kwargs["past_key_values"][0][0].shape[-2]
            > tokenized_messages["input_ids"].shape[-1]
        ):
            raise ValueError("Past key values are larger than the input_ids")

        past_key_values = self.past_key_values_kwargs.get("past_key_values", None)
        if past_key_values is not None:
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)

        generated_sequences = self.model.generate(
            **tokenized_messages,
            do_sample=True,
            temperature=temperature,
            max_new_tokens=max_tokens,
            pad_token_id=self.pad_token_id,
            return_dict_in_generate=True,
            use_cache=True,
            past_key_values=past_key_values,
            return_legacy_cache=True,
        )
        # Cache the past key values
        if self.use_hot_cache:
            self.past_key_values_kwargs["past_key_values"] = (
                generated_sequences.past_key_values
            )
            self.past_token_ids = generated_sequences.sequences

        # Decode the output
        text_responses = self.tokenizer.batch_decode(
            generated_sequences.sequences[:, prompt_tokens:],
            skip_special_tokens=False,
        )

        text_responses = [
            text_response.replace("<|eot_id|>", "") for text_response in text_responses
        ]

        _, total_tokens_used = generated_sequences.sequences.shape
        return text_responses, total_tokens_used


class OpenAIGenerator:
    def __init__(self, use_images=False, model_name="gpt-4o-mini"):
        self.client = OpenAI()
        self.use_images = use_images
        self.model_name = model_name

    def reset(self):
        pass

    def prepare_messages(
        self,
        history: History,
        max_messages_window: int,
        system_prompt: dict = None,
        prompt_images: list = [],
    ) -> tuple[list[dict], list]:
        """
        Prepare the image messages for the model
        """
        message_window = history.dialogue_history[-max_messages_window:]
        # remove the first assistant message if it is present
        if len(message_window) > 0 and message_window[0]["role"] == "assistant":
            message_window = message_window[1:]
        # add the system prompt if the first message is not a system message
        if message_window[0]["role"] != "system" and system_prompt is not None:
            message_window = [system_prompt] + message_window

        if self.use_images:
            image_list = prompt_images + history.images
            img_idx = -1
            seen_images = 0
            # iterate through the messages in reverse order to assign images
            for i in range(len(message_window) - 1, -1, -1):
                new_content_list = []
                for content in message_window[i]["content"]:
                    if content["type"] == "text":
                        new_content_list.append(content)
                    elif content["type"] == "image":
                        base64_image = numpy_to_base64(image_list[img_idx])
                        img_idx -= 1
                        seen_images += 1
                        new_content = {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{base64_image}"
                            },
                        }
                        new_content_list.append(new_content)
                message_window[i]["content"] = new_content_list
            assert seen_images <= len(image_list), "Too many images"

        return message_window, []

    def generate_unconstrained(
        self,
        batch_messages: list[list[dict]],
        max_tokens=256,
        **kwargs,
    ) -> tuple[list[str], int]:
        contents = []
        tokens_used = 0
        for messages in batch_messages:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                temperature=0.0,
                max_tokens=max_tokens,
                top_p=1,
                frequency_penalty=0,
                presence_penalty=0,
                stop=["\n", "\n\n"],
            )
            content = response.choices[0].message.content
            tokens_used += response.usage.total_tokens
            contents.append(content)
        return contents, tokens_used


class OAMGenerator:
    def __init__(
        self,
        model_name,
    ):
        self.model_name = model_name
        logger.info("Loading model")
        time_now = time.time()
        self.model = PlancraftOAM.from_pretrained(model_name)
        self.model.cuda()
        logger.info(f"Model loaded in {time.time() - time_now:.2f} seconds")
        # compile
        time_now = time.time()
        self.model = torch.compile(self.model)
        logger.info(f"Model compiled in {time.time() - time_now:.2f} seconds")

        self.model.eval()
        self.model.tokenizer.pad_token = self.model.tokenizer.eos_token
        self.model.tokenizer.truncation_side = "left"

    def reset(self):
        pass

    def prepare_messages(
        self,
        history: History,
        max_messages_window: int,
        system_prompt: dict = None,
        prompt_images: list = [],
    ) -> tuple[list[dict], list]:
        """
        Prepare the messages using a history
        """
        message_window = history.dialogue_history[-max_messages_window:]
        # remove the first assistant message if it is present
        if len(message_window) > 0 and message_window[0]["role"] == "assistant":
            message_window = message_window[1:]
        # add the system prompt if the first message is not a system message
        if message_window[0]["role"] != "system" and system_prompt is not None:
            message_window = [system_prompt] + message_window

        image_window = []

        image_list = prompt_images + history.images
        image_count = 0

        # iterate through the messages to count how many images are present
        new_message_window = []
        for m in message_window:
            message_content = m["content"]
            message_role = m["role"]

            if "\ninventory:\n" in message_content:
                message_content = (
                    message_content.split("\ninventory:\n")[0]
                    + "\ninventory:<|inventory|>"
                )

            if "<|inventory|>" in message_content:
                image_count += 1

            new_message_window.append(
                {"role": message_role, "content": message_content}
            )

        assert image_count <= len(image_list), "Too many images"
        # take the most recent images from the end of the list
        image_window = image_list[-image_count:]
        image_window = [Image.fromarray(img) for img in image_window]

        return new_message_window, image_window

    @torch.inference_mode()
    def generate_unconstrained(
        self,
        batch_messages: list[list[dict]],
        images: list,
        max_tokens: int = 256,
        temperature=0.6,
        **kwargs,
    ) -> tuple[list[str], int]:
        """
        Generate unconstrained text based on the batch of messages.
        """
        text_responses, total_tokens_used = self.model.generate(
            batch_messages=batch_messages,
            batch_images=images,
            do_sample=True,
            temperature=temperature,
            max_new_tokens=max_tokens,
        )
        return text_responses, total_tokens_used
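
For context, a minimal sketch of how one of these generator classes might be driven: window the dialogue with prepare_messages, then call generate_unconstrained once. This example is not part of the package; the History construction, the dialogue text, and the system_prompt dict below are illustrative assumptions, and it presumes OPENAI_API_KEY is set for the OpenAI client.

# Hedged usage sketch (not shipped in the wheel): drive OpenAIGenerator with a
# plancraft History. Assumes History() can be constructed empty and exposes
# `dialogue_history` and `images` as lists; the prompt text is made up.
from plancraft.models.base import History
from plancraft.models.generators import OpenAIGenerator

generator = OpenAIGenerator(use_images=False, model_name="gpt-4o-mini")

history = History()  # assumed to start with an empty dialogue_history
history.dialogue_history.append(
    {"role": "user", "content": "Craft a wooden pickaxe from the current inventory."}
)
system_prompt = {"role": "system", "content": "You are a Minecraft crafting planner."}

# Window the recent dialogue and prepend the system prompt.
messages, _ = generator.prepare_messages(
    history=history,
    max_messages_window=30,
    system_prompt=system_prompt,
)

# Single unconstrained generation for a batch of one conversation.
responses, tokens_used = generator.generate_unconstrained(batch_messages=[messages])
print(responses[0], tokens_used)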