plancraft 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plancraft/__init__.py +0 -0
- plancraft/config.py +155 -0
- plancraft/environments/__init__.py +0 -0
- plancraft/environments/actions.py +218 -0
- plancraft/environments/env_real.py +316 -0
- plancraft/environments/env_symbolic.py +212 -0
- plancraft/environments/items.py +10 -0
- plancraft/environments/planner.py +109 -0
- plancraft/environments/recipes.py +542 -0
- plancraft/environments/sampler.py +224 -0
- plancraft/evaluator.py +273 -0
- plancraft/models/__init__.py +21 -0
- plancraft/models/act.py +184 -0
- plancraft/models/base.py +152 -0
- plancraft/models/bbox_model.py +492 -0
- plancraft/models/dummy.py +54 -0
- plancraft/models/few_shot_images/__init__.py +16 -0
- plancraft/models/generators.py +480 -0
- plancraft/models/oam.py +283 -0
- plancraft/models/oracle.py +265 -0
- plancraft/models/prompts.py +158 -0
- plancraft/models/react.py +93 -0
- plancraft/models/utils.py +289 -0
- plancraft/train/dataset.py +187 -0
- plancraft/utils.py +84 -0
- {plancraft-0.1.2.dist-info → plancraft-0.1.3.dist-info}/METADATA +1 -1
- plancraft-0.1.3.dist-info/RECORD +30 -0
- plancraft-0.1.3.dist-info/top_level.txt +1 -0
- plancraft-0.1.2.dist-info/RECORD +0 -5
- plancraft-0.1.2.dist-info/top_level.txt +0 -1
- {plancraft-0.1.2.dist-info → plancraft-0.1.3.dist-info}/LICENSE +0 -0
- {plancraft-0.1.2.dist-info → plancraft-0.1.3.dist-info}/WHEEL +0 -0
plancraft/environments/sampler.py
ADDED
@@ -0,0 +1,224 @@
import math
import random
from collections import Counter

import numpy as np
from plancraft.environments.items import all_data, ALL_ITEMS
from plancraft.environments.recipes import RECIPES
from plancraft.environments.planner import optimal_planner, get_ancestors


MAX_STACK_SIZE = {}
for data_item in all_data["items"]:
    if data_item["stackable"]:
        MAX_STACK_SIZE[data_item["type"]] = data_item["stackSize"]
    else:
        MAX_STACK_SIZE[data_item["type"]] = 1


def sample_distractors(
    exclude_set: set = None, num_distractors: int = 16
) -> dict[str, int]:
    distractors = {}
    while len(distractors) < num_distractors:
        item = random.choice(ALL_ITEMS)
        if exclude_set is not None and item in exclude_set:
            continue
        count = random.randint(1, MAX_STACK_SIZE[item])
        distractors[item] = count
    return distractors


def assign_to_slots(inventory: dict[str, int]) -> list[dict]:
    # slots available outside of crafting interface
    available_slots = list(range(10, 46))
    random.shuffle(available_slots)
    inventory_list = []

    for item, total_count in inventory.items():
        while total_count > 0:
            if len(available_slots) == 0:
                print("Not enough slots available")
                break
            slot = available_slots.pop()
            count_in_slot = min(total_count, MAX_STACK_SIZE[item])
            inventory_list.append({"slot": slot, "item": item, "count": count_in_slot})
            total_count -= count_in_slot

    return inventory_list


def sample_recipes(
    target: str,
    overall_exclude_set: set,
    target_count: int = 1,
    current_depth=0,
    max_depth=20,
) -> tuple[set, set]:
    # stop if the depth is too high
    if current_depth > max_depth:
        return {}, overall_exclude_set

    # get all the recipes that can craft the target
    overall_exclude_set.update([target])
    local_exclude_set = set()
    random_recipes = []
    for r in RECIPES[target]:
        recipe_inputs, exclude_set = r.sample_inputs()
        # if inputs are already in the exclude set, skip this recipe (ensures no cycle)
        if exclude_set.intersection(overall_exclude_set):
            return {}, overall_exclude_set
        local_exclude_set.update(exclude_set)
        random_recipes.append((r, recipe_inputs))

    overall_exclude_set |= local_exclude_set

    # no recipes found
    if len(random_recipes) == 0:
        return {}, overall_exclude_set

    # sample a random recipe
    random_recipe = random.choice(random_recipes)
    recipe, start_inputs = random_recipe

    # recipe will not produce enough
    if recipe.result.count < target_count:
        # must do recipe X times
        recipe_multiplier = math.ceil(target_count / recipe.result.count)
        start_inputs = {k: v * recipe_multiplier for k, v in start_inputs.items()}

    for input_item in list(start_inputs.keys()):
        # randomize depth first search to end early
        if random.choice([True, False]):
            continue

        children_recipe_inputs, updated_exclude_set = sample_recipes(
            target=input_item,
            overall_exclude_set=overall_exclude_set,
            target_count=start_inputs[input_item],
            current_depth=current_depth + 1,
        )
        if len(children_recipe_inputs) == 0:
            continue

        overall_exclude_set.update(updated_exclude_set)

        # remove recipe input item since we are crafting it
        start_inputs[input_item] = 0

        # add the children recipe inputs
        for item, count in children_recipe_inputs.items():
            start_inputs[item] = start_inputs.get(item, 0) + count

    overall_exclude_set = overall_exclude_set - {None}
    start_inputs = {k: v for k, v in start_inputs.items() if v > 0}

    return start_inputs, overall_exclude_set


def remove_ancestor_items(target: str, inventory: dict[str, int]) -> dict[str, int]:
    ancestors = set(get_ancestors(target))
    possible_items = set(inventory.keys())
    items_to_remove = list(ancestors.intersection(possible_items))
    num_items = random.randint(1, len(items_to_remove))
    for item in random.sample(items_to_remove, num_items):
        count_to_remove = random.randint(1, inventory[item])
        inventory[item] -= count_to_remove
        if inventory[item] == 0:
            del inventory[item]
    return inventory


def construct_example(
    target: str,
    num_distractors: int = 16,
    impossible=False,
) -> list[dict]:
    """
    For a given target object, number of distractors, and impossible flag
    Return a dictionary with the start inventory for the crafting task

    The crafting task should be to craft the target, the inventory should contain
    the resources required for the recipe to be crafted.

    The number of distractors are how many random items should be added to the inventory.

    If impossible is True, the target item should not be craftable with the given inventory.
    """

    # sample the recipe
    inventory, overall_exclude_set = sample_recipes(target, set())
    if impossible:
        # if impossible then remove one or more items from the inventory
        inventory = remove_ancestor_items(
            target,
            inventory,
        )

    # add distractors to the inventory
    distractors = sample_distractors(overall_exclude_set, num_distractors)
    inventory.update(distractors)

    optimal_path = optimal_planner(target, inventory)
    # @TODO this is a hack to ensure that we don't have impossible examples
    while optimal_path is not None and impossible:
        inventory = remove_ancestor_items(target, inventory)
        optimal_path = optimal_planner(target, inventory)

    # assign to slots
    inventory_list = assign_to_slots(inventory)
    example = {
        "inventory": inventory,
        "slotted_inventory": inventory_list,
        "target": target,
        "num_distractors": num_distractors,
        "impossible": impossible,
    }
    # either impossible and no path or not impossible and path exists
    assert (impossible and optimal_path is None) or (
        not impossible and optimal_path is not None
    )

    if not impossible:
        example["optimal_path_length"] = len(optimal_path)
        example["optimal_path"] = [r.result.item for (r, i) in optimal_path]
        example["inventory_trace"] = [i for (r, i) in optimal_path]
        items_used, unique_items_used = calculate_stats_from_inventory_trace(
            [example["inventory"]] + example["inventory_trace"]
        )
        example["items_used"] = items_used
        example["unique_items_used"] = unique_items_used

    return example


def calculate_stats_from_inventory_trace(
    inventory_trace: list[dict],
) -> tuple[int, int]:
    total_items_used = 0
    total_unique_items_used = 0

    for a, b in zip(inventory_trace[:-1], inventory_trace[1:]):
        diff = Counter(a) - Counter(b)
        total_items_used += sum(diff.values())
        total_unique_items_used += len(diff)

    return total_items_used, total_unique_items_used


def generate_dataset(seed=2024, distractors=[4, 8, 16], num_examples=10):
    random.seed(seed)
    np.random.seed(seed)

    dataset = []
    for recipe_target in list(RECIPES.keys()):
        if len(RECIPES[recipe_target]) == 0:
            continue
        for num_distractors in distractors:
            for _ in range(num_examples):
                example = construct_example(
                    target=recipe_target, num_distractors=num_distractors
                )
                dataset.append(example)

    return dataset
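For illustration, a minimal usage sketch of the sampler module above. It is not part of the packaged file, and the target name "ladder" is only assumed to be a key of RECIPES:

import random

from plancraft.environments.sampler import construct_example, generate_dataset

# one solvable example for a single assumed target with 4 distractor item stacks
random.seed(0)
example = construct_example(target="ladder", num_distractors=4)
print(example["target"], example["impossible"], example["optimal_path_length"])

# or sweep every craftable RECIPES target at several distractor counts
dataset = generate_dataset(seed=2024, distractors=[4, 8], num_examples=2)
print(len(dataset), "examples generated")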
plancraft/evaluator.py
ADDED
@@ -0,0 +1,273 @@
import json
import os
import random
import string
import time

import imageio
import pandas as pd
import torch
import wandb
from loguru import logger
from tqdm import tqdm

from plancraft.config import EvalConfig, PlancraftExample
from plancraft.environments.actions import StopAction
from plancraft.environments.env_real import RealPlancraft
from plancraft.environments.env_symbolic import SymbolicPlancraft
from plancraft.models import get_model

wandb.require("core")


class Evaluator:
    """
    The evaluator class handles the environment loop and model interaction

    The environment is created based on the configuration and the examples are loaded from the dataset.
    """

    def __init__(self, cfg: EvalConfig):
        self.cfg = cfg
        self.output_dir = (
            f"{cfg.plancraft.output_dir}/{self.evaluator_name()}/{cfg.plancraft.split}"
        )
        self.generation_number = 0

        self.examples: list[PlancraftExample] = self.load_dataset(cfg.plancraft.split)

        self.environment = self.create_env(cfg)
        self.model = get_model(cfg)

        self.record_frames = not (cfg.plancraft.environment.symbolic)

        # no_op action
        self.no_op = self.environment.action_space.no_op()

    def evaluator_name(self) -> str:
        symb_str = "real"
        if self.cfg.plancraft.environment.symbolic:
            symb_str = "symb"

        if self.cfg.plancraft.use_maskrcnn:
            symb_str += "_mrcnn"

        model_name = self.cfg.plancraft.model.split("/")[-1]
        if self.cfg.plancraft.adapter != "":
            model_name = self.cfg.plancraft.adapter.split("/")[-1]

        mode = self.cfg.plancraft.mode
        if mode in ["dummy", "oracle"]:
            return f"{mode}_{symb_str}"

        actions = "|".join(self.cfg.plancraft.valid_actions)
        return f"{self.cfg.plancraft.mode}_{symb_str}_{model_name}_{actions}"

    def save_results_dict(self, example: PlancraftExample, results_dict: dict):
        output_dir = f"{self.output_dir}/{self.generation_number}"
        os.makedirs(output_dir, exist_ok=True)
        json_path = f"{output_dir}/{example.id}.json"
        with open(json_path, "w") as f:
            json.dump(results_dict, f, indent=4)
        wandb.save(json_path, policy="now")

    def save_images(self, example: PlancraftExample, frames: list):
        if len(frames) == 0:
            return
        output_dir = f"{self.output_dir}/{self.generation_number}"
        os.makedirs(output_dir, exist_ok=True)
        imageio.mimsave(f"{output_dir}/{example.id}.gif", frames)
        # upload to wandb
        wandb.save(f"{output_dir}/{example.id}.gif", policy="now")

    def load_results_dict(self, example: PlancraftExample) -> dict:
        path = f"{self.output_dir}/{self.generation_number}/{example.id}.json"
        if not os.path.exists(path) or not self.cfg.plancraft.resume:
            return None
        with open(path, "r") as f:
            return json.load(f)

    def create_env(self, cfg: EvalConfig) -> RealPlancraft | SymbolicPlancraft:
        if cfg.plancraft.environment.symbolic:
            return SymbolicPlancraft(inventory=[])
        return RealPlancraft(
            inventory=[],
            symbolic_action_space=cfg.plancraft.environment.symbolic_action_space,
            symbolic_observation_space=cfg.plancraft.environment.symbolic_observation_space,
            preferred_spawn_biome=cfg.plancraft.environment.preferred_spawn_biome,
            resolution=cfg.plancraft.environment.resolution,
        )

    def close(self):
        self.environment.close()

    def load_dataset(self, dataset_split: str) -> list[PlancraftExample]:
        with open(f"data/{dataset_split}.json", "r") as f:
            dataset = json.load(f)
            return [PlancraftExample(**example) for example in dataset]

    def reset(
        self,
        example: PlancraftExample,
    ):
        current_inventory = example.slotted_inventory
        self.environment.fast_reset(new_inventory=current_inventory)
        # do a no op to an initial observation
        obs, _, _, _ = self.environment.step(self.no_op)
        # assert that the inventory is correct
        if "inventory" in obs:
            for item in current_inventory:
                slot = item["slot"]
                if (
                    obs["inventory"][slot]["type"] != item["type"]
                    or obs["inventory"][slot]["quantity"] != item["quantity"]
                ) and item["type"] != "air":
                    logger.warning(f"Inventory does not match expected for slot {slot}")
                    logger.warning(f"Expected {item}")
                    logger.warning(f"Got {obs['inventory'][slot]}")
                    # try again
                    self.reset(example)

        objective = f"Craft an item of type: {example.target}"
        self.model.reset_history(objective=objective)

    def check_done(self, inventory: list[dict[str, int]], target: str):
        """
        Check that target object is obtained
        """
        for item in inventory:
            if target == item["type"]:
                # ensure item is taken out of crafting slot
                if "slot" in item and item["slot"] != 0:
                    return True
                if "index" in item and item["index"] != 0:
                    return True
        return False

    @torch.no_grad()
    def eval_all_examples(self, progress_bar=False) -> list:
        results = []
        action = self.no_op.copy()

        pbar = tqdm(
            total=len(self.examples),
            disable=not progress_bar,
        )
        correct = 0
        count = 0

        for example in self.examples:
            if resume_result := self.load_results_dict(example):
                pbar.update(self.cfg.plancraft.max_steps)
                results.append(resume_result)
                continue

            success = False

            self.reset(example)
            action = self.no_op.copy()

            while (
                not self.model.history.check_stuck()
                and self.model.history.num_steps < self.cfg.plancraft.max_steps
            ):
                # if the action is stop then we end the episode
                if isinstance(action, StopAction):
                    # if the action is stop and task is impossible then success
                    # otherwise we should not have stopped
                    success = example.impossible
                    break

                # step action
                observation, _, _, _ = self.environment.step(action)

                # check if the episode is done
                success = self.check_done(observation["inventory"], example.target)
                # exit if success
                if success:
                    break

                # predict next action
                action = self.model.step(observation)

            # save results and reset
            result = {
                "success": success,
                "recipe_type": example.recipe_type,
                "number_of_steps": self.model.history.num_steps,
                "model_trace": self.model.history.trace(),
                "example_id": example.id,
                "impossible": example.impossible,
            }
            results.append(result)
            self.save_results_dict(example, result)
            self.save_images(example, self.model.history.images)

            correct += int(result["success"])
            count += 1

            acc = correct / count
            pbar.set_postfix(correct=correct, count=count, acc=acc)
            pbar.update(1)

        return results

    def eval_all(self):
        logger.info(
            f"Running evaluation over {len(self.examples)} examples {self.cfg.plancraft.num_generations} times."
        )
        run_name = (
            f"{self.evaluator_name()} {self.cfg.plancraft.split}".replace(" ", "_")
            .replace(".", "_")
            .strip()
        )

        for n in range(self.cfg.plancraft.num_generations):
            logger.info(f"Generation {n+1}/{self.cfg.plancraft.num_generations}")
            run_id = "".join(random.choices(string.ascii_lowercase, k=5))
            generation_run_name = run_name + f"_{run_id}"

            wandb.init(
                name=generation_run_name,
                project=self.cfg.wandb.project,
                entity=self.cfg.wandb.entity,
                mode=self.cfg.wandb.mode,
                group=self.cfg.plancraft.model,
                job_type=self.cfg.plancraft.mode,
                config=self.cfg.model_dump(),
            )
            time_now = time.time()

            results_list = self.eval_all_examples(progress_bar=True)

            results_df = pd.DataFrame(results_list)

            output = {
                "avg_success_rate": results_df["success"].mean(),
                "avg_number_of_steps": results_df["number_of_steps"].mean(),
                "avg_num_tokens_used": results_df["model_trace"]
                .apply(pd.Series)["tokens_used"]
                .mean(),
            }

            # calculate success rate for each recipe type
            recipe_types = results_df["recipe_type"].unique()
            for recipe_type in recipe_types:
                mask = results_df["recipe_type"] == recipe_type
                success_rate = results_df[mask]["success"].mean()
                output[f"{recipe_type}_success_rate"] = success_rate

            time_elapsed = time.time() - time_now
            logger.info(f"Time elapsed: {time_elapsed:.2f}s")

            logger.info(output)
            wandb.log(output)
            table = wandb.Table(
                dataframe=results_df[["success", "number_of_steps", "example_id"]]
            )
            wandb.log({"results": table})
            wandb.finish()

            self.generation_number += 1

        logger.info("Done")
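A minimal driver sketch for the Evaluator above, assuming an EvalConfig instance has already been built from plancraft/config.py (not shown in this diff); the helper function is illustrative, not part of the wheel:

from plancraft.config import EvalConfig
from plancraft.evaluator import Evaluator


def run(cfg: EvalConfig) -> None:
    # builds the environment and model from cfg and loads data/<split>.json
    evaluator = Evaluator(cfg)
    try:
        # one wandb run per generation; per-example JSON/GIF results go to output_dir
        evaluator.eval_all()
    finally:
        evaluator.close()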
plancraft/models/__init__.py
ADDED
@@ -0,0 +1,21 @@
from plancraft.models.base import ABCModel

from plancraft.config import EvalConfig
from plancraft.models.dummy import DummyModel
from plancraft.models.react import ReactModel
from plancraft.models.oracle import OracleModel
from plancraft.models.act import ActModel


def get_model(cfg: EvalConfig) -> ABCModel:
    """
    Factory get model (default: ReactModel)
    """
    if cfg.plancraft.mode == "dummy":
        return DummyModel(cfg)
    elif cfg.plancraft.mode == "oracle":
        return OracleModel(cfg)
    elif cfg.plancraft.mode == "act":
        return ActModel(cfg)
    else:
        return ReactModel(cfg)
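For reference, the factory above dispatches on cfg.plancraft.mode; a hypothetical call, assuming cfg is an EvalConfig built elsewhere:

from plancraft.models import get_model

# "dummy" -> DummyModel, "oracle" -> OracleModel, "act" -> ActModel,
# any other mode falls back to ReactModel
model = get_model(cfg)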
plancraft/models/act.py
ADDED
@@ -0,0 +1,184 @@
import copy
import torch
from dotenv import load_dotenv

from plancraft.config import EvalConfig
from plancraft.environments.actions import (
    NoOp,
    StopAction,
    SymbolicAction,
)
from plancraft.models.base import ABCModel, History
from plancraft.models.bbox_model import IntegratedBoundingBoxModel
from plancraft.models.few_shot_images import load_prompt_images
from plancraft.models.generators import (
    OAMGenerator,
    OpenAIGenerator,
    TransformersGenerator,
)
from plancraft.models.prompts import get_prompt_example, get_system_prompt
from plancraft.models.utils import (
    convert_observation_to_message,
    parse_content_response,
)


load_dotenv()


class ActModel(ABCModel):
    """
    Model that does action without thinking step
    """

    def __init__(self, cfg: EvalConfig):
        assert (
            cfg.plancraft.environment.symbolic_action_space
        ), "Real action space unsupported"
        self.cfg = cfg
        self.env_is_multimodal = not cfg.plancraft.environment.symbolic
        self.use_maskrcnn = cfg.plancraft.use_maskrcnn
        self.use_multimodal_content_format = cfg.plancraft.use_multimodal_content_format
        self.use_text_inventory = cfg.plancraft.use_text_inventory
        self.use_images = cfg.plancraft.use_images

        self.bbox_model = None
        if self.use_maskrcnn:
            assert self.env_is_multimodal, "MaskRCNN only supported in multimodal mode"
            self.bbox_model = IntegratedBoundingBoxModel.from_pretrained(
                "gautierdag/plancraft-maskrcnn"
            )
            self.bbox_model.eval()
            if torch.cuda.is_available():
                self.bbox_model.cuda()
            # MaskRCNN is not multimodal model but a separate model

        self.few_shot = cfg.plancraft.few_shot
        self.use_system_prompt = cfg.plancraft.system_prompt
        self.max_invalid_actions = 3

        # underlying language model
        if "gpt-4o" in cfg.plancraft.model:
            self.use_multimodal_content_format = True
            self.llm = OpenAIGenerator(
                use_images=self.use_images, model_name=cfg.plancraft.model
            )
        elif "oam" in cfg.plancraft.model:
            self.llm = OAMGenerator(model_name=cfg.plancraft.model)
        else:
            # model is transformers based
            self.llm = TransformersGenerator(
                model_name=cfg.plancraft.model,
                tokenizer_name=cfg.plancraft.tokenizer,
                quantize=cfg.plancraft.quantize,
                use_hot_cache=cfg.plancraft.hot_cache,
                adapter_name=cfg.plancraft.adapter,
            )

        self.prompt_images = []

        self.valid_actions = cfg.plancraft.valid_actions
        self.system_prompt_text = get_system_prompt(self.valid_actions)

        examples = get_prompt_example(
            self.valid_actions,
            use_text_inventory=self.use_text_inventory,
            use_multimodal_content_format=self.use_multimodal_content_format,
            use_images=self.use_images,
        )
        if self.env_is_multimodal and self.use_images:
            self.prompt_images = load_prompt_images()

        if self.use_multimodal_content_format:
            self.system_prompt = {
                "role": "system",
                "content": [
                    {"text": copy.deepcopy(self.system_prompt_text), "type": "text"}
                ],
            }
        else:
            self.system_prompt = {
                "role": "system",
                "content": copy.deepcopy(self.system_prompt_text),
            }

        if not self.few_shot:
            examples = []
        if not self.use_system_prompt:
            self.system_prompt = None

        self.history = History(
            initial_dialogue=examples,
            use_multimodal_content_format=self.use_multimodal_content_format,
        )

        self.max_messages_window = cfg.plancraft.max_message_window
        self.kv_cache = None

    def reset_history(
        self,
        objective: str,
    ):
        examples = []
        if self.few_shot:
            examples = get_prompt_example(
                self.valid_actions,
                use_text_inventory=self.use_text_inventory,
                use_multimodal_content_format=self.use_multimodal_content_format,
                use_images=self.use_images,
            )

        self.history.reset(objective=objective, initial_dialogue=examples)
        self.llm.reset()

    def step(self, observation: dict) -> SymbolicAction | StopAction:
        self.history.add_observation_to_history(observation)

        # add observation to history
        observation_message = convert_observation_to_message(
            observation,
            objective=self.history.objective,
            bbox_model=self.bbox_model,
            oam_model="oam" in self.llm.model_name,
            use_text_inventory=self.use_text_inventory,
            use_multimodal_content_format=self.use_multimodal_content_format,
            use_images=self.use_images,
        )
        self.history.add_message_to_history(content=observation_message, role="user")

        # Iterate until valid action
        i = 0
        while i < self.max_invalid_actions:
            # add observation to history
            message_window, image_window = self.llm.prepare_messages(
                history=self.history,
                max_messages_window=self.max_messages_window,
                system_prompt=self.system_prompt,
                prompt_images=self.prompt_images,
            )
            action_messages, action_token_used = self.llm.generate_unconstrained(
                batch_messages=[message_window],
                images=[image_window],
            )
            self.history.tokens_used += action_token_used

            action_message = action_messages[0].split("\n")[0].strip()

            self.history.add_message_to_history(
                content=action_message, role="assistant"
            )
            response = parse_content_response(
                action_message, valid_actions=self.valid_actions
            )
            if not isinstance(response, str):
                # valid action
                self.history.add_action_to_history(response)
                return response

            self.history.add_message_to_history(
                content=response,
            )
            i += 1

        # if no action is found after max_invalid_actions, default to useless move action
        return NoOp()
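To make the control flow concrete, a hypothetical rollout that mirrors Evaluator.eval_all_examples; the function below is illustrative only and not part of the wheel:

from plancraft.config import EvalConfig
from plancraft.environments.actions import StopAction
from plancraft.environments.env_symbolic import SymbolicPlancraft
from plancraft.models.act import ActModel


def rollout(cfg: EvalConfig, env: SymbolicPlancraft, target: str) -> None:
    model = ActModel(cfg)
    model.reset_history(objective=f"Craft an item of type: {target}")

    action = env.action_space.no_op()
    for _ in range(cfg.plancraft.max_steps):
        if isinstance(action, StopAction):
            break  # the model has judged the task impossible
        observation, _, _, _ = env.step(action)
        # step() returns a SymbolicAction, a StopAction, or a NoOp after
        # max_invalid_actions failed generations in a row
        action = model.step(observation)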