plancraft 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl
- plancraft-0.1.2.dist-info/METADATA +74 -0
- plancraft-0.1.2.dist-info/RECORD +5 -0
- {plancraft-0.1.0.dist-info → plancraft-0.1.2.dist-info}/WHEEL +1 -1
- plancraft-0.1.2.dist-info/top_level.txt +1 -0
- environments/__init__.py +0 -0
- environments/actions.py +0 -218
- environments/env_real.py +0 -315
- environments/env_symbolic.py +0 -215
- environments/items.py +0 -10
- environments/planner.py +0 -109
- environments/recipes.py +0 -542
- environments/sampler.py +0 -224
- models/__init__.py +0 -21
- models/act.py +0 -184
- models/base.py +0 -152
- models/bbox_model.py +0 -492
- models/dummy.py +0 -54
- models/few_shot_images/__init__.py +0 -16
- models/generators.py +0 -483
- models/oam.py +0 -284
- models/oracle.py +0 -268
- models/prompts.py +0 -158
- models/react.py +0 -98
- models/utils.py +0 -289
- plancraft-0.1.0.dist-info/METADATA +0 -53
- plancraft-0.1.0.dist-info/RECORD +0 -26
- plancraft-0.1.0.dist-info/top_level.txt +0 -3
- train/dataset.py +0 -187
- {plancraft-0.1.0.dist-info → plancraft-0.1.2.dist-info}/LICENSE +0 -0
models/generators.py
DELETED
@@ -1,483 +0,0 @@
import logging
import os
import time

import torch
from dotenv import load_dotenv
from openai import OpenAI
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoModelForVision2Seq,
    AutoProcessor,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from transformers.cache_utils import DynamicCache

from plancraft.models.base import History
from plancraft.models.oam import PlancraftOAM
from plancraft.models.utils import (
    get_downloaded_models,
    numpy_to_base64,
    tokenize,
)


logger = logging.getLogger(__name__)

load_dotenv()


class TransformersGenerator:
    def __init__(
        self,
        model_name: str,
        tokenizer_name: str = "same",
        quantize=False,
        use_images=False,
        use_hot_cache=True,
        adapter_name="",
        **kwargs,
    ):
        self.model_name = model_name
        self.use_hot_cache = use_hot_cache

        if tokenizer_name == "same":
            tokenizer_name = model_name

        self.use_images = use_images
        model_name, model_kwargs = self.build_model_kwargs(
            model_name, quantize=quantize
        )
        self.processor = None
        if "idefics" in model_name:
            assert use_images, "Idefics model requires multimodal input"
            self.tokenizer = AutoProcessor.from_pretrained(
                tokenizer_name,
                **model_kwargs,
            )
            self.tokenizer.eos_token_id = self.tokenizer.tokenizer.eos_token_id
            logger.info("Loading model")
            time_now = time.time()
            self.model = AutoModelForVision2Seq.from_pretrained(
                model_name,
                device_map="auto",
                **model_kwargs,
            )
            logger.info(f"Model loaded in {time.time() - time_now:.2f} seconds")
            # set pad_token_id
            if self.tokenizer.tokenizer.pad_token_id:
                self.pad_token_id = self.tokenizer.tokenizer.pad_token_id
            else:
                self.pad_token_id = self.tokenizer.tokenizer.eos_token_id
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_name,
                token=os.getenv("HF_TOKEN"),  # trust_remote_code=True
                padding_side="left",  # ensure that the padding is on the left
            )
            logger.info("Loading model")
            time_now = time.time()
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map="auto",
                **model_kwargs,
            )
            logger.info(f"Model loaded in {time.time() - time_now:.2f} seconds")

            # load OA adapter
            if adapter_name != "":
                logger.info(f"Loading adapter and tokenizer from {adapter_name}")
                self.tokenizer = AutoTokenizer.from_pretrained(
                    adapter_name,
                    padding_side="left",
                )
                self.model.resize_token_embeddings(len(self.tokenizer))
                self.model.load_adapter(adapter_name)

            # set pad_token_id
            if self.tokenizer.pad_token_id:
                self.pad_token_id = self.tokenizer.pad_token_id
            else:
                self.tokenizer.pad_token = self.tokenizer.eos_token
                self.pad_token_id = self.tokenizer.eos_token_id

        # compile
        time_now = time.time()
        self.model = torch.compile(self.model)
        logger.info(f"Model compiled in {time.time() - time_now:.2f} seconds")

        self.model.eval()
        if self.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.model.config.pad_token_id = self.model.config.eos_token_id
            self.tokenizer.truncation_side = "left"

        self.past_key_values_kwargs = {}
        self.past_token_ids = None

    def truncate_kv_cache(self, new_token_ids: torch.Tensor):
        """
        Truncate the key-value cache to the length over which the past_ids overlap with the new_ids.
        Uses:
            past_ids: torch.Tensor [B, T]
            new_ids: torch.Tensor [B, T]
            kv_cache: tuple[tuple[torch.Tensor]]: tuple of key-value cache tensors

        NOTE: this essentially implements system prompt caching in the worst case when using batch_size==1
        """
        if (
            self.past_token_ids is None
            or "past_key_values" not in self.past_key_values_kwargs
        ):
            return

        # caching doesn't seem to work with multimodal models
        if self.use_images:
            self.past_key_values_kwargs = {}
            return

        past_batch_size, past_seq_len = self.past_token_ids.shape
        new_batch_size, new_seq_len = new_token_ids.shape

        # If the batch size has changed, reset the cache
        if past_batch_size != new_batch_size:
            self.past_key_values_kwargs = {}
            return

        min_shape = min(past_seq_len, new_seq_len)
        compare_past = (
            self.past_token_ids[:, :min_shape] != new_token_ids[:, :min_shape]
        )

        # All tokens are the same - no need to truncate
        if not compare_past.any():
            return

        # Find the first token that is different between the past and new tokens
        seq_min = torch.argmax(compare_past.double(), dim=1).min()

        # Truncate the key-value cache to the overlapping prefix of past_ids and new_ids
        # assumes shape is [num_layers, num_heads, seq_len, hidden_size]
        self.past_key_values_kwargs["past_key_values"] = [
            [kv[:, :, :seq_min, :] for kv in kvs]
            for kvs in self.past_key_values_kwargs["past_key_values"]
        ]

    @staticmethod
    def build_model_kwargs(model_name: str, **kwargs) -> tuple[str, dict]:
        model_kwargs = {
            "token": os.getenv("HF_TOKEN"),
            # "attn_implementation": "flash_attention_2",
            # "trust_remote_code": True,
        }
        quantize = kwargs.get("quantize", False)
        if quantize == "int4":
            model_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
            )
        elif quantize == "int8":
            model_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_8bit=True,
            )
        else:
            model_kwargs["torch_dtype"] = torch.bfloat16

        downloaded_models = get_downloaded_models()
        if model_name in downloaded_models:
            model_kwargs["local_files_only"] = True
            model_name = downloaded_models[model_name]
            logger.info(f"Using local model {model_name}")
        if "/plancraft/outputs" in model_name:
            model_kwargs["local_files_only"] = True
            logger.info(f"Using local model {model_name}")

        return model_name, model_kwargs

    def reset(self):
        # NOTE: a past_key_values cache with a rolling window
        # is not maximally useful as the beginning shifts over time
        # and therefore the cache is invalidated
        self.past_key_values_kwargs = {}
        self.past_token_ids = None

    def prepare_messages(
        self,
        history: History,
        max_messages_window: int,
        system_prompt: dict = None,
        prompt_images: list = [],
    ) -> tuple[list[dict], list]:
        """
        Prepare the messages using a history
        """
        message_window = history.dialogue_history[-max_messages_window:]
        # remove the first assistant message if it is present
        if len(message_window) > 0 and message_window[0]["role"] == "assistant":
            message_window = message_window[1:]
        # add the system prompt if the first message is not a system message
        if message_window[0]["role"] != "system" and system_prompt is not None:
            message_window = [system_prompt] + message_window

        image_window = []
        if self.use_images:
            image_list = prompt_images + history.images
            image_count = 0
            # count how many image contents appear in the message window
            for m in message_window:
                for content in m["content"]:
                    if content["type"] == "image":
                        image_count += 1
            assert image_count <= len(image_list), "Too many images"
            image_window = image_list[-image_count:]
            image_window = [Image.fromarray(img) for img in image_window]

        return message_window, image_window

    @torch.inference_mode()
    def generate_unconstrained(
        self,
        batch_messages: list[list[dict]],
        start_messages_generation: str = "",
        max_tokens: int = 256,
        temperature=0.6,
        **kwargs,
    ) -> tuple[list[str], int]:
        """
        Generate unconstrained text based on the batch of messages.
        """
        if self.use_images:
            assert "images" in kwargs, "Images required for multimodal model"

        tokenized_messages = tokenize(
            self.model,
            self.tokenizer,
            batch_messages,
            start_messages_generation=[start_messages_generation] * len(batch_messages),
            max_tokens=max_tokens,
            images=kwargs.get("images") if self.use_images else None,
        )
        prompt_tokens = tokenized_messages["input_ids"].shape[-1]

        # Send to the same device as the model
        tokenized_messages = {
            k: v.to(self.model.device) for k, v in tokenized_messages.items()
        }

        # Truncate the key-value cache
        self.truncate_kv_cache(tokenized_messages["input_ids"])

        if (
            "past_key_values" in self.past_key_values_kwargs
            and self.past_key_values_kwargs["past_key_values"][0][0].shape[-2]
            > tokenized_messages["input_ids"].shape[-1]
        ):
            raise ValueError("Past key values are larger than the input_ids")

        past_key_values = self.past_key_values_kwargs.get("past_key_values", None)
        if past_key_values is not None:
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)

        generated_sequences = self.model.generate(
            **tokenized_messages,
            do_sample=True,
            temperature=temperature,
            max_new_tokens=max_tokens,
            pad_token_id=self.pad_token_id,
            return_dict_in_generate=True,
            use_cache=True,
            past_key_values=past_key_values,
            return_legacy_cache=True,
        )
        # Cache the past key values
        if self.use_hot_cache:
            self.past_key_values_kwargs["past_key_values"] = (
                generated_sequences.past_key_values
            )
            self.past_token_ids = generated_sequences.sequences

        # Decode the output
        text_responses = self.tokenizer.batch_decode(
            generated_sequences.sequences[:, prompt_tokens:],
            skip_special_tokens=False,
        )

        text_responses = [
            text_response.replace("<|eot_id|>", "") for text_response in text_responses
        ]

        _, total_tokens_used = generated_sequences.sequences.shape
        return text_responses, total_tokens_used


class OpenAIGenerator:
    def __init__(self, use_images=False, model_name="gpt-4o-mini"):
        self.client = OpenAI()
        self.use_images = use_images
        self.model_name = model_name

    def reset(self):
        pass

    def prepare_messages(
        self,
        history: History,
        max_messages_window: int,
        system_prompt: dict = None,
        prompt_images: list = [],
    ) -> tuple[list[dict], list]:
        """
        Prepare the image messages for the model
        """
        message_window = history.dialogue_history[-max_messages_window:]
        # remove the first assistant message if it is present
        if len(message_window) > 0 and message_window[0]["role"] == "assistant":
            message_window = message_window[1:]
        # add the system prompt if the first message is not a system message
        if message_window[0]["role"] != "system" and system_prompt is not None:
            message_window = [system_prompt] + message_window

        if self.use_images:
            image_list = prompt_images + history.images
            img_idx = -1
            seen_images = 0
            # iterate through the messages in reverse order to assign images
            for i in range(len(message_window) - 1, -1, -1):
                new_content_list = []
                for content in message_window[i]["content"]:
                    if content["type"] == "text":
                        new_content_list.append(content)
                    elif content["type"] == "image":
                        base64_image = numpy_to_base64(image_list[img_idx])
                        img_idx -= 1
                        seen_images += 1
                        new_content = {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{base64_image}"
                            },
                        }
                        new_content_list.append(new_content)
                message_window[i]["content"] = new_content_list
            assert seen_images <= len(image_list), "Too many images"

        return message_window, []

    def generate_unconstrained(
        self,
        batch_messages: list[list[dict]],
        max_tokens=256,
        **kwargs,
    ) -> tuple[list[str], int]:
        contents = []
        tokens_used = 0
        for messages in batch_messages:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                temperature=0.0,
                max_tokens=max_tokens,
                top_p=1,
                frequency_penalty=0,
                presence_penalty=0,
                stop=["\n", "\n\n"],
            )
            content = response.choices[0].message.content
            tokens_used += response.usage.total_tokens
            contents.append(content)
        return contents, tokens_used


class OAMGenerator:
    def __init__(
        self,
        model_name,
    ):
        self.model_name = model_name
        logger.info("Loading model")
        time_now = time.time()
        self.model = PlancraftOAM.from_pretrained(model_name)
        self.model.cuda()
        logger.info(f"Model loaded in {time.time() - time_now:.2f} seconds")
        # compile
        time_now = time.time()
        self.model = torch.compile(self.model)
        logger.info(f"Model compiled in {time.time() - time_now:.2f} seconds")

        self.model.eval()
        self.model.tokenizer.pad_token = self.model.tokenizer.eos_token
        self.model.tokenizer.truncation_side = "left"

    def reset(self):
        pass

    def prepare_messages(
        self,
        history: History,
        max_messages_window: int,
        system_prompt: dict = None,
        prompt_images: list = [],
    ) -> tuple[list[dict], list]:
        """
        Prepare the messages using a history
        """
        message_window = history.dialogue_history[-max_messages_window:]
        # remove the first assistant message if it is present
        if len(message_window) > 0 and message_window[0]["role"] == "assistant":
            message_window = message_window[1:]
        # add the system prompt if the first message is not a system message
        if message_window[0]["role"] != "system" and system_prompt is not None:
            message_window = [system_prompt] + message_window

        image_window = []

        image_list = prompt_images + history.images
        image_count = 0

        # iterate through the messages to count how many images are present
        new_message_window = []
        for m in message_window:
            message_content = m["content"]
            message_role = m["role"]

            if "\ninventory:\n" in message_content:
                message_content = (
                    message_content.split("\ninventory:\n")[0]
                    + "\ninventory:<|inventory|>"
                )

            if "<|inventory|>" in message_content:
                image_count += 1

            new_message_window.append(
                {"role": message_role, "content": message_content}
            )

        assert image_count <= len(image_list), "Too many images"
        # get messages from end of queue
        image_window = image_list[-image_count:]
        image_window = [Image.fromarray(img) for img in image_window]

        return new_message_window, image_window

    @torch.inference_mode()
    def generate_unconstrained(
        self,
        batch_messages: list[list[dict]],
        images: list,
        max_tokens: int = 256,
        temperature=0.6,
        **kwargs,
    ) -> tuple[list[str], int]:
        """
        Generate unconstrained text based on the batch of messages.
        """
        text_responses, total_tokens_used = self.model.generate(
            batch_messages=batch_messages,
            batch_images=images,
            do_sample=True,
            temperature=temperature,
            max_new_tokens=max_tokens,
        )
        return text_responses, total_tokens_used
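
For orientation, here is a minimal, hypothetical sketch of how the removed generator classes were presumably driven from the rest of plancraft in 0.1.0 (the History construction, window size, and prompt contents below are illustrative assumptions, not code taken from the package):

# Hypothetical usage of the deleted OpenAIGenerator (requires OPENAI_API_KEY in the environment).
from plancraft.models.base import History
from plancraft.models.generators import OpenAIGenerator  # module path as it existed in 0.1.0

generator = OpenAIGenerator(use_images=False, model_name="gpt-4o-mini")
history = History()  # assumed to expose .dialogue_history (list of chat messages) and .images
history.dialogue_history.append({"role": "user", "content": "craft an item"})

# Build the message window (optionally prefixed with a system prompt), then query the model once.
messages, _ = generator.prepare_messages(
    history,
    max_messages_window=30,
    system_prompt={"role": "system", "content": "You are a crafting assistant."},
)
responses, tokens_used = generator.generate_unconstrained(batch_messages=[messages])
print(responses[0], tokens_used)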