plancraft 0.1.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,224 @@
1
+ import math
2
+ import random
3
+ from collections import Counter
4
+
5
+ import numpy as np
6
+ from plancraft.environments.items import all_data, ALL_ITEMS
7
+ from plancraft.environments.recipes import RECIPES
8
+ from plancraft.environments.planner import optimal_planner, get_ancestors
9
+
10
+
11
# Maximum number of units of an item type that a single slot can hold.
# Built once at import time from the "items" entries of ``all_data``:
# stackable items use their declared "stackSize", all others are capped at 1.
MAX_STACK_SIZE = {}
for data_item in all_data["items"]:
    MAX_STACK_SIZE[data_item["type"]] = (
        data_item["stackSize"] if data_item["stackable"] else 1
    )
17
+
18
+
19
def sample_distractors(
    exclude_set: set | None = None, num_distractors: int = 16
) -> dict[str, int]:
    """
    Sample random distractor items that are not part of ``exclude_set``.

    Args:
        exclude_set: items that must never be picked (``None`` means no exclusion).
        num_distractors: number of distinct item types to sample.

    Returns:
        Mapping of item name to a random count between 1 and the item's
        ``MAX_STACK_SIZE``.

    Raises:
        ValueError: if fewer than ``num_distractors`` distinct items are
            available once ``exclude_set`` is taken into account.
    """
    excluded = exclude_set if exclude_set is not None else set()

    # Rejection sampling below can only terminate if enough distinct,
    # non-excluded items exist; fail loudly instead of spinning forever.
    num_available = len(set(ALL_ITEMS) - set(excluded))
    if num_distractors > num_available:
        raise ValueError(
            f"Cannot sample {num_distractors} distractors: "
            f"only {num_available} items are available"
        )

    distractors = {}
    while len(distractors) < num_distractors:
        item = random.choice(ALL_ITEMS)
        if item in excluded:
            continue
        # an item drawn twice simply gets its count re-rolled
        distractors[item] = random.randint(1, MAX_STACK_SIZE[item])
    return distractors
30
+
31
+
32
def assign_to_slots(inventory: dict[str, int]) -> list[dict]:
    """
    Distribute an inventory over randomly chosen slots.

    Each item is split into stacks of at most ``MAX_STACK_SIZE[item]`` and every
    stack is placed in its own slot, drawn without replacement from slots 10-45.

    Args:
        inventory: mapping of item name to total count.

    Returns:
        List of ``{"slot": int, "item": str, "count": int}`` entries. If the
        inventory needs more slots than are available, a warning is printed
        once and the stacks that did fit are returned (best effort).
    """
    # slots available outside of crafting interface
    available_slots = list(range(10, 46))
    random.shuffle(available_slots)
    inventory_list = []

    for item, total_count in inventory.items():
        while total_count > 0:
            if not available_slots:
                # Nothing else can be placed: warn once and stop, rather than
                # repeating the warning for every remaining item.
                print("Not enough slots available")
                return inventory_list
            slot = available_slots.pop()
            count_in_slot = min(total_count, MAX_STACK_SIZE[item])
            inventory_list.append({"slot": slot, "item": item, "count": count_in_slot})
            total_count -= count_in_slot

    return inventory_list
49
+
50
+
51
def sample_recipes(
    target: str,
    overall_exclude_set: set,
    target_count: int = 1,
    current_depth=0,
    max_depth=20,
) -> tuple[dict[str, int], set]:
    """
    Randomly sample a crafting tree for ``target`` and return the items it needs.

    One recipe producing ``target`` is chosen at random. Each of its inputs is
    then, with probability 1/2, replaced by the inputs of a recursively sampled
    recipe that crafts that input.

    Args:
        target: item to craft.
        overall_exclude_set: items already claimed by the tree being built;
            it is updated in place with ``target`` and the sampled recipes'
            exclude sets.
        target_count: number of units of ``target`` required.
        current_depth: depth of this call in the recursion.
        max_depth: depth beyond which no further expansion is attempted.

    Returns:
        ``(inputs, exclude_set)`` where ``inputs`` maps item name to the count
        required (empty when nothing could be sampled) and ``exclude_set`` is
        the set of items that distractors must avoid. On the success path the
        returned set is a new set with ``None`` removed.
    """
    # stop if the depth is too high
    if current_depth > max_depth:
        return {}, overall_exclude_set

    # get all the recipes that can craft the target
    overall_exclude_set.add(target)
    local_exclude_set = set()
    random_recipes = []
    for r in RECIPES[target]:
        recipe_inputs, exclude_set = r.sample_inputs()
        # if any sampled input is already excluded, give up on this target
        # entirely (ensures no cycle).
        # NOTE(review): the original comment said "skip this recipe", but the
        # code returns instead of continuing; behaviour kept as-is - confirm.
        if exclude_set.intersection(overall_exclude_set):
            return {}, overall_exclude_set
        local_exclude_set.update(exclude_set)
        random_recipes.append((r, recipe_inputs))

    overall_exclude_set |= local_exclude_set

    # no recipes found
    if len(random_recipes) == 0:
        return {}, overall_exclude_set

    # sample a random recipe
    recipe, start_inputs = random.choice(random_recipes)

    # recipe will not produce enough
    if recipe.result.count < target_count:
        # must do recipe X times
        recipe_multiplier = math.ceil(target_count / recipe.result.count)
        start_inputs = {k: v * recipe_multiplier for k, v in start_inputs.items()}

    # iterate over a snapshot of the keys: start_inputs grows inside the loop
    for input_item in list(start_inputs.keys()):
        # randomize depth first search to end early
        if random.choice([True, False]):
            continue

        children_recipe_inputs, updated_exclude_set = sample_recipes(
            target=input_item,
            overall_exclude_set=overall_exclude_set,
            target_count=start_inputs[input_item],
            current_depth=current_depth + 1,
            # forward the caller's limit instead of silently resetting to the default
            max_depth=max_depth,
        )
        if len(children_recipe_inputs) == 0:
            continue

        overall_exclude_set.update(updated_exclude_set)

        # remove recipe input item since we are crafting it
        start_inputs[input_item] = 0

        # add the children recipe inputs
        for item, count in children_recipe_inputs.items():
            start_inputs[item] = start_inputs.get(item, 0) + count

    overall_exclude_set = overall_exclude_set - {None}
    start_inputs = {k: v for k, v in start_inputs.items() if v > 0}

    return start_inputs, overall_exclude_set
117
+
118
+
119
def remove_ancestor_items(target: str, inventory: dict[str, int]) -> dict[str, int]:
    """
    Randomly remove some ancestor items of ``target`` from ``inventory``.

    A random non-empty subset of the inventory items that are ancestors of
    ``target`` (as reported by ``get_ancestors``) is chosen, and a random
    quantity (at least 1) of each is removed. Items whose count drops to zero
    are deleted.

    Args:
        target: item whose ancestors are candidates for removal.
        inventory: mapping of item name to count; modified in place.

    Returns:
        The same ``inventory`` object, after removal.

    Raises:
        ValueError: if ``inventory`` contains no ancestor of ``target``.
    """
    ancestors = set(get_ancestors(target))
    # Sort the candidates: set iteration order of strings changes between
    # interpreter runs (hash randomization), which would make the seeded
    # random.sample below non-reproducible.
    items_to_remove = sorted(ancestors.intersection(inventory.keys()))
    if not items_to_remove:
        raise ValueError(
            f"inventory contains no ancestor items of {target!r} to remove"
        )

    num_items = random.randint(1, len(items_to_remove))
    for item in random.sample(items_to_remove, num_items):
        count_to_remove = random.randint(1, inventory[item])
        inventory[item] -= count_to_remove
        if inventory[item] == 0:
            del inventory[item]
    return inventory
130
+
131
+
132
def construct_example(
    target: str,
    num_distractors: int = 16,
    impossible=False,
) -> dict:
    """
    Build one crafting example for ``target``.

    The start inventory contains the resources of a randomly sampled recipe
    tree for ``target`` plus ``num_distractors`` random items that are not part
    of that tree. If ``impossible`` is True, ancestor items are removed until
    the planner can no longer find a path to the target.

    Args:
        target: item the task asks to craft.
        num_distractors: number of random distractor items added to the inventory.
        impossible: whether the target must NOT be craftable from the inventory.

    Returns:
        Dictionary with keys "inventory", "slotted_inventory", "target",
        "num_distractors" and "impossible". Possible examples additionally
        carry "optimal_path_length", "optimal_path", "inventory_trace",
        "items_used" and "unique_items_used".
    """

    # sample the recipe
    inventory, overall_exclude_set = sample_recipes(target, set())
    if impossible:
        # if impossible then remove one or more items from the inventory
        inventory = remove_ancestor_items(
            target,
            inventory,
        )

    # add distractors to the inventory
    distractors = sample_distractors(overall_exclude_set, num_distractors)
    inventory.update(distractors)

    optimal_path = optimal_planner(target, inventory)
    # @TODO this is a hack: keep removing ancestor items until an example
    # flagged as impossible really has no solution
    while optimal_path is not None and impossible:
        inventory = remove_ancestor_items(target, inventory)
        optimal_path = optimal_planner(target, inventory)

    # assign to slots
    inventory_list = assign_to_slots(inventory)
    example = {
        "inventory": inventory,
        "slotted_inventory": inventory_list,
        "target": target,
        "num_distractors": num_distractors,
        "impossible": impossible,
    }
    # either impossible and no path or not impossible and path exists
    assert (impossible and optimal_path is None) or (
        not impossible and optimal_path is not None
    )

    if not impossible:
        # optimal_path is a sequence of (recipe, inventory-after-step) pairs
        example["optimal_path_length"] = len(optimal_path)
        example["optimal_path"] = [r.result.item for (r, i) in optimal_path]
        example["inventory_trace"] = [i for (r, i) in optimal_path]
        items_used, unique_items_used = calculate_stats_from_inventory_trace(
            [example["inventory"]] + example["inventory_trace"]
        )
        example["items_used"] = items_used
        example["unique_items_used"] = unique_items_used

    return example
193
+
194
+
195
def calculate_stats_from_inventory_trace(
    inventory_trace: list[dict],
) -> tuple[int, int]:
    """
    Summarise how many items are consumed along a trace of inventories.

    For every pair of consecutive inventories, the items whose count decreased
    are treated as used in that step.

    Returns:
        ``(total_items_used, total_unique_items_used)``: the summed quantity
        decrease over all steps, and the summed number of distinct item types
        that decreased in each step.
    """
    # Counter subtraction keeps only positive differences, i.e. what was consumed
    consumed_per_step = [
        Counter(before) - Counter(after)
        for before, after in zip(inventory_trace, inventory_trace[1:])
    ]
    total_items_used = sum(sum(step.values()) for step in consumed_per_step)
    total_unique_items_used = sum(len(step) for step in consumed_per_step)
    return total_items_used, total_unique_items_used
207
+
208
+
209
def generate_dataset(seed=2024, distractors=(4, 8, 16), num_examples=10):
    """
    Generate crafting examples for every target that has at least one recipe.

    Args:
        seed: seed applied to both ``random`` and ``numpy.random``.
        distractors: distractor counts to generate examples for.
        num_examples: examples generated per (target, distractor count) pair.

    Returns:
        List of example dictionaries as produced by ``construct_example``.
    """
    random.seed(seed)
    np.random.seed(seed)

    dataset = []
    # Iterate over a snapshot of the keys.
    # NOTE(review): presumably RECIPES can gain keys while examples are built
    # (e.g. if it is a defaultdict) - confirm before removing the copy.
    for recipe_target in list(RECIPES.keys()):
        if len(RECIPES[recipe_target]) == 0:
            continue
        for num_distractors in distractors:
            for _ in range(num_examples):
                example = construct_example(
                    target=recipe_target, num_distractors=num_distractors
                )
                dataset.append(example)

    return dataset
models/__init__.py ADDED
@@ -0,0 +1,21 @@
1
+ from plancraft.models.base import ABCModel
2
+
3
+ from plancraft.config import EvalConfig
4
+ from plancraft.models.dummy import DummyModel
5
+ from plancraft.models.react import ReactModel
6
+ from plancraft.models.oracle import OracleModel
7
+ from plancraft.models.act import ActModel
8
+
9
+
10
def get_model(cfg: EvalConfig) -> ABCModel:
    """
    Build the model selected by ``cfg.plancraft.mode``.

    "dummy", "oracle" and "act" map to their dedicated model classes; any
    other mode yields a ReactModel.
    """
    mode = cfg.plancraft.mode
    if mode == "dummy":
        return DummyModel(cfg)
    if mode == "oracle":
        return OracleModel(cfg)
    if mode == "act":
        return ActModel(cfg)
    # default
    return ReactModel(cfg)
models/act.py ADDED
@@ -0,0 +1,184 @@
1
+ import copy
2
+ import torch
3
+ from dotenv import load_dotenv
4
+
5
+ from plancraft.config import EvalConfig
6
+ from plancraft.environments.actions import (
7
+ NoOp,
8
+ StopAction,
9
+ SymbolicAction,
10
+ )
11
+ from plancraft.models.base import ABCModel, History
12
+ from plancraft.models.bbox_model import IntegratedBoundingBoxModel
13
+ from plancraft.models.few_shot_images import load_prompt_images
14
+ from plancraft.models.generators import (
15
+ OAMGenerator,
16
+ OpenAIGenerator,
17
+ TransformersGenerator,
18
+ )
19
+ from plancraft.models.prompts import get_prompt_example, get_system_prompt
20
+ from plancraft.models.utils import (
21
+ convert_observation_to_message,
22
+ parse_content_response,
23
+ )
24
+
25
+
26
+ load_dotenv()
27
+
28
+
29
class ActModel(ABCModel):
    """
    Model that emits an action directly, without a preceding thinking step.
    """

    def __init__(self, cfg: EvalConfig):
        """Configure the optional bounding-box model, the LLM backend, prompts and history from ``cfg``."""
        plancraft_cfg = cfg.plancraft
        assert (
            plancraft_cfg.environment.symbolic_action_space
        ), "Real action space unsupported"
        self.cfg = cfg
        self.env_is_multimodal = not plancraft_cfg.environment.symbolic
        self.use_maskrcnn = plancraft_cfg.use_maskrcnn
        self.use_multimodal_content_format = plancraft_cfg.use_multimodal_content_format
        self.use_text_inventory = plancraft_cfg.use_text_inventory
        self.use_images = plancraft_cfg.use_images

        # MaskRCNN is not a multimodal model but a separate model
        self.bbox_model = None
        if self.use_maskrcnn:
            assert self.env_is_multimodal, "MaskRCNN only supported in multimodal mode"
            self.bbox_model = IntegratedBoundingBoxModel.from_pretrained(
                "gautierdag/plancraft-maskrcnn"
            )
            self.bbox_model.eval()
            if torch.cuda.is_available():
                self.bbox_model.cuda()

        self.few_shot = plancraft_cfg.few_shot
        self.use_system_prompt = plancraft_cfg.system_prompt
        # number of unparsable generations tolerated per step
        self.max_invalid_actions = 3

        # underlying language model, selected by substring of the model name
        model_name = plancraft_cfg.model
        if "gpt-4o" in model_name:
            # this backend always uses the multimodal content format
            self.use_multimodal_content_format = True
            self.llm = OpenAIGenerator(
                use_images=self.use_images, model_name=model_name
            )
        elif "oam" in model_name:
            self.llm = OAMGenerator(model_name=model_name)
        else:
            # model is transformers based
            self.llm = TransformersGenerator(
                model_name=model_name,
                tokenizer_name=plancraft_cfg.tokenizer,
                quantize=plancraft_cfg.quantize,
                use_hot_cache=plancraft_cfg.hot_cache,
                adapter_name=plancraft_cfg.adapter,
            )

        self.prompt_images = []

        self.valid_actions = plancraft_cfg.valid_actions
        self.system_prompt_text = get_system_prompt(self.valid_actions)

        examples = get_prompt_example(
            self.valid_actions,
            use_text_inventory=self.use_text_inventory,
            use_multimodal_content_format=self.use_multimodal_content_format,
            use_images=self.use_images,
        )
        if self.env_is_multimodal and self.use_images:
            self.prompt_images = load_prompt_images()

        # the multimodal format wraps the text in a list of typed parts
        prompt_text = copy.deepcopy(self.system_prompt_text)
        if self.use_multimodal_content_format:
            prompt_content = [{"text": prompt_text, "type": "text"}]
        else:
            prompt_content = prompt_text
        self.system_prompt = {"role": "system", "content": prompt_content}

        if not self.few_shot:
            examples = []
        if not self.use_system_prompt:
            self.system_prompt = None

        self.history = History(
            initial_dialogue=examples,
            use_multimodal_content_format=self.use_multimodal_content_format,
        )

        self.max_messages_window = plancraft_cfg.max_message_window
        self.kv_cache = None

    def reset_history(
        self,
        objective: str,
    ):
        """Reset the history for a new ``objective`` (re-seeding few-shot examples if enabled) and reset the LLM."""
        if self.few_shot:
            examples = get_prompt_example(
                self.valid_actions,
                use_text_inventory=self.use_text_inventory,
                use_multimodal_content_format=self.use_multimodal_content_format,
                use_images=self.use_images,
            )
        else:
            examples = []

        self.history.reset(objective=objective, initial_dialogue=examples)
        self.llm.reset()

    def step(self, observation: dict) -> SymbolicAction | StopAction:
        """
        Record ``observation`` and query the LLM for the next action.

        The first line of each generation is parsed; if parsing yields a string
        (an error message), it is appended to the dialogue and generation is
        retried, up to ``max_invalid_actions`` times. If no generation parses,
        a ``NoOp`` is returned.
        """
        self.history.add_observation_to_history(observation)

        # turn the observation into a user message and add it to the dialogue
        observation_message = convert_observation_to_message(
            observation,
            objective=self.history.objective,
            bbox_model=self.bbox_model,
            oam_model="oam" in self.llm.model_name,
            use_text_inventory=self.use_text_inventory,
            use_multimodal_content_format=self.use_multimodal_content_format,
            use_images=self.use_images,
        )
        self.history.add_message_to_history(content=observation_message, role="user")

        # Iterate until valid action
        for _ in range(self.max_invalid_actions):
            # build the (possibly truncated) message window sent to the LLM
            message_window, image_window = self.llm.prepare_messages(
                history=self.history,
                max_messages_window=self.max_messages_window,
                system_prompt=self.system_prompt,
                prompt_images=self.prompt_images,
            )
            action_messages, action_token_used = self.llm.generate_unconstrained(
                batch_messages=[message_window],
                images=[image_window],
            )
            self.history.tokens_used += action_token_used

            # only the first line of the generation is considered
            first_line = action_messages[0].split("\n")[0]
            action_message = first_line.strip()

            self.history.add_message_to_history(
                content=action_message, role="assistant"
            )
            response = parse_content_response(
                action_message, valid_actions=self.valid_actions
            )
            if isinstance(response, str):
                # a string response is fed back into the dialogue before retrying
                self.history.add_message_to_history(
                    content=response,
                )
                continue

            # valid action
            self.history.add_action_to_history(response)
            return response

        # if no action is found after max_invalid_actions, default to a no-op
        return NoOp()
models/base.py ADDED
@@ -0,0 +1,152 @@
1
+ import abc
2
+
3
+ from copy import copy
4
+ from collections import Counter
5
+
6
+ from plancraft.environments.actions import (
7
+ SymbolicMoveAction,
8
+ RealActionInteraction,
9
+ SymbolicSmeltAction,
10
+ )
11
+
12
+
13
+ class History:
14
+ def __init__(
15
+ self,
16
+ objective: str = "",
17
+ initial_dialogue: list[dict] = [],
18
+ use_multimodal_content_format=False,
19
+ ):
20
+ self.dialogue_history = initial_dialogue
21
+ self.initial_dialogue_length = len(initial_dialogue)
22
+ self.action_history = []
23
+ self.inventory_history = []
24
+ self.inventory_counters = []
25
+ self.images = []
26
+ self.objective = objective
27
+ self.tokens_used = 0
28
+ self.use_multimodal_content_format = use_multimodal_content_format
29
+
30
+ def add_message_to_history(self, content: str | dict, role="user"):
31
+ if role == "assistant":
32
+ print(content)
33
+
34
+ if isinstance(content, dict):
35
+ assert "content" in content, "content key not found in message"
36
+ content["role"] = role
37
+ self.dialogue_history.append(content)
38
+ else:
39
+ # fix for listed content type
40
+ if self.use_multimodal_content_format:
41
+ return self.add_message_to_history(
42
+ content={
43
+ "content": [{"type": "text", "text": content}],
44
+ "role": role,
45
+ },
46
+ role=role,
47
+ )
48
+ else:
49
+ self.dialogue_history.append({"role": role, "content": content})
50
+
51
+ def add_action_to_history(
52
+ self, action: SymbolicSmeltAction | RealActionInteraction | SymbolicMoveAction
53
+ ):
54
+ if action is None:
55
+ return
56
+ self.action_history.append(action.model_dump())
57
+
58
+ def add_inventory_to_history(self, inventory: list[dict[str, int]]):
59
+ self.inventory_history.append(inventory)
60
+
61
+ # count inventory
62
+ counter = Counter()
63
+ for item in inventory:
64
+ # ignore slot 0
65
+ if "slot" in item and item["slot"] == 0:
66
+ continue
67
+ if "index" in item and item["index"] == 0:
68
+ continue
69
+ counter[item["type"]] += item["quantity"]
70
+
71
+ self.inventory_counters.append(counter)
72
+
73
+ def add_image_to_history(self, image):
74
+ self.images.append(image)
75
+
76
+ def add_observation_to_history(self, observation: dict):
77
+ if observation is None:
78
+ return
79
+ if "inventory" in observation:
80
+ clean_inv = []
81
+ # remove empty slots
82
+ for item in observation["inventory"]:
83
+ if item["quantity"] > 0:
84
+ clean_inv.append(item)
85
+ self.add_inventory_to_history(clean_inv)
86
+ if "pov" in observation:
87
+ self.add_image_to_history(observation["pov"])
88
+
89
+ def __str__(self):
90
+ return str(self.dialogue_history)
91
+
92
+ def reset(self, objective: str = "", initial_dialogue: list[dict] = []):
93
+ self.dialogue_history = initial_dialogue
94
+ self.action_history = []
95
+ self.inventory_history = []
96
+ self.inventory_counters = []
97
+ self.images = []
98
+ self.objective = objective
99
+
100
+ def set_objective(self, objective: str):
101
+ self.objective = objective
102
+
103
+ def trace(self):
104
+ return {
105
+ "dialogue_history": copy(
106
+ self.dialogue_history[self.initial_dialogue_length :]
107
+ ),
108
+ "action_history": copy(self.action_history),
109
+ "inventory_history": copy(self.inventory_history),
110
+ "objective": copy(self.objective),
111
+ "tokens_used": copy(self.tokens_used),
112
+ }
113
+
114
+ @property
115
+ def num_steps(self):
116
+ return len(self.action_history)
117
+
118
+ def check_stuck(self, max_steps_no_change: int = 10) -> bool:
119
+ """
120
+ If inventory content does not change for max_steps_no_change steps
121
+ the agent is considered stuck.
122
+
123
+ With N=10, the oracle solver can still solve 100% of the examples
124
+ """
125
+ if len(self.inventory_counters) <= max_steps_no_change:
126
+ return False
127
+
128
+ return all(
129
+ c == self.inventory_counters[-max_steps_no_change - 1]
130
+ for c in self.inventory_counters[-max_steps_no_change - 1 :]
131
+ )
132
+
133
+
134
class ABCModel(abc.ABC):
    """
    Interface a model has to implement to be driven by the evaluator.
    """

    @abc.abstractmethod
    def step(
        self, observation: list[dict]
    ) -> list[SymbolicMoveAction | RealActionInteraction | SymbolicSmeltAction]:
        """
        Produce a valid action (one of the 3 available types) per observation.

        This is declared as a batch operation: given a list of observations,
        implementations return a list with one action for each of them.
        """
        raise NotImplementedError

    def reset_history(self, objective: str = ""):
        """Reset ``self.history`` for a new ``objective``."""
        self.history.reset(objective=objective)