plancraft 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
environments/sampler.py DELETED
@@ -1,224 +0,0 @@
1
- import math
2
- import random
3
- from collections import Counter
4
-
5
- import numpy as np
6
- from plancraft.environments.items import all_data, ALL_ITEMS
7
- from plancraft.environments.recipes import RECIPES
8
- from plancraft.environments.planner import optimal_planner, get_ancestors
9
-
10
-
11
# Maximum number of units of each item type that fit in one inventory slot:
# stackable items use their declared "stackSize", everything else is capped at 1.
MAX_STACK_SIZE = {}
for data_item in all_data["items"]:
    stack_limit = data_item["stackSize"] if data_item["stackable"] else 1
    MAX_STACK_SIZE[data_item["type"]] = stack_limit
17
-
18
-
19
def sample_distractors(
    exclude_set: set | None = None, num_distractors: int = 16
) -> dict[str, int]:
    """
    Sample random distractor items together with a random count for each.

    Items are drawn uniformly from ALL_ITEMS (rejecting anything contained in
    `exclude_set`) until `num_distractors` distinct items have been collected.
    Each draw assigns a count between 1 and the item's MAX_STACK_SIZE entry
    (inclusive); drawing an already selected item again re-rolls its count.

    Args:
        exclude_set: items that must never be returned (None means no exclusion).
        num_distractors: number of distinct items to return.

    Returns:
        Mapping of item name to sampled count.

    Raises:
        ValueError: if fewer than `num_distractors` distinct items are eligible.
    """
    excluded = exclude_set if exclude_set is not None else set()
    # Without this guard the rejection loop below would never terminate when
    # the request cannot be satisfied.
    num_candidates = sum(1 for item in set(ALL_ITEMS) if item not in excluded)
    if num_candidates < num_distractors:
        raise ValueError(
            f"Cannot sample {num_distractors} distractors: "
            f"only {num_candidates} eligible items"
        )

    distractors = {}
    while len(distractors) < num_distractors:
        item = random.choice(ALL_ITEMS)
        if exclude_set is not None and item in exclude_set:
            continue
        distractors[item] = random.randint(1, MAX_STACK_SIZE[item])
    return distractors
30
-
31
-
32
def assign_to_slots(inventory: dict[str, int]) -> list[dict]:
    """
    Distribute an inventory over randomly chosen inventory slots.

    Each item is split into stacks of at most MAX_STACK_SIZE[item] and every
    stack is placed in its own slot, drawn from a shuffled pool of slots 10-45.

    Args:
        inventory: mapping of item name to total count.

    Returns:
        List of {"slot", "item", "count"} dicts, one per occupied slot. If the
        slots run out, a warning is printed once and the partial assignment
        made so far is returned.
    """
    # slots available outside of crafting interface
    free_slots = list(range(10, 46))
    random.shuffle(free_slots)
    slotted_inventory = []

    for item, remaining in inventory.items():
        while remaining > 0:
            if not free_slots:
                # Nothing else can be placed, so stop instead of repeating the
                # warning for every remaining item.
                print("Not enough slots available")
                return slotted_inventory
            slot = free_slots.pop()
            count_in_slot = min(remaining, MAX_STACK_SIZE[item])
            slotted_inventory.append(
                {"slot": slot, "item": item, "count": count_in_slot}
            )
            remaining -= count_in_slot

    return slotted_inventory
49
-
50
-
51
def sample_recipes(
    target: str,
    overall_exclude_set: set,
    target_count: int = 1,
    current_depth=0,
    max_depth=20,
) -> tuple[dict[str, int], set]:
    """
    Randomly expand `target` into a set of starting items that can craft it.

    One recipe for `target` is chosen at random; each of its inputs is then,
    with probability 1/2, recursively replaced by the inputs of a recipe that
    crafts that input.

    Args:
        target: item to craft.
        overall_exclude_set: items already used along the expansion; it is
            updated in place with `target` and the sampled recipes' exclude sets.
        target_count: how many units of `target` are required.
        current_depth: recursion depth of this call.
        max_depth: depth beyond which no further expansion is attempted.

    Returns:
        (inputs, exclude_set): `inputs` maps item name to required count and is
        empty when no expansion was possible; `exclude_set` is the accumulated
        exclusion set (with None removed on the successful path).
    """
    # stop if the depth is too high
    if current_depth > max_depth:
        return {}, overall_exclude_set

    # get all the recipes that can craft the target
    overall_exclude_set.update([target])
    local_exclude_set = set()
    random_recipes = []
    for r in RECIPES[target]:
        recipe_inputs, exclude_set = r.sample_inputs()
        # If a recipe's inputs overlap the exclude set, the whole expansion of
        # this target is abandoned (ensures no cycle).
        # NOTE(review): the original comment said "skip this recipe", but the
        # code returns rather than continuing; kept as is so sampling is
        # unchanged. Confirm which behaviour is intended.
        if exclude_set.intersection(overall_exclude_set):
            return {}, overall_exclude_set
        local_exclude_set.update(exclude_set)
        random_recipes.append((r, recipe_inputs))

    overall_exclude_set |= local_exclude_set

    # no recipes found
    if len(random_recipes) == 0:
        return {}, overall_exclude_set

    # sample a random recipe
    recipe, start_inputs = random.choice(random_recipes)

    # recipe will not produce enough
    if recipe.result.count < target_count:
        # must do recipe X times
        recipe_multiplier = math.ceil(target_count / recipe.result.count)
        start_inputs = {k: v * recipe_multiplier for k, v in start_inputs.items()}

    # iterate over a snapshot of the keys: start_inputs gains entries below
    for input_item in list(start_inputs.keys()):
        # randomize depth first search to end early
        if random.choice([True, False]):
            continue

        children_recipe_inputs, updated_exclude_set = sample_recipes(
            target=input_item,
            overall_exclude_set=overall_exclude_set,
            target_count=start_inputs[input_item],
            current_depth=current_depth + 1,
            # forward the limit so a caller-supplied max_depth is honoured at
            # every level instead of silently reverting to the default
            max_depth=max_depth,
        )
        if len(children_recipe_inputs) == 0:
            continue

        overall_exclude_set.update(updated_exclude_set)

        # remove recipe input item since we are crafting it
        start_inputs[input_item] = 0

        # add the children recipe inputs
        for item, count in children_recipe_inputs.items():
            start_inputs[item] = start_inputs.get(item, 0) + count

    # builds a new set, so the caller's set object keeps any None entry
    overall_exclude_set = overall_exclude_set - {None}
    start_inputs = {k: v for k, v in start_inputs.items() if v > 0}

    return start_inputs, overall_exclude_set
117
-
118
-
119
def remove_ancestor_items(target: str, inventory: dict[str, int]) -> dict[str, int]:
    """
    Randomly remove some units of items that are ancestors of `target`.

    A random non-empty subset of the inventory items returned by
    get_ancestors(target) is chosen, and a random number of units (at least
    one) is removed from each; items that drop to zero are deleted.

    Args:
        target: item whose ancestors are candidates for removal.
        inventory: mapping of item name to count; modified in place.

    Returns:
        The same `inventory` object, after removal.

    Raises:
        ValueError: if the inventory contains no ancestor of `target`.
    """
    ancestors = set(get_ancestors(target))
    # Sort the candidates: set iteration order of strings changes with hash
    # randomization, which would make seeded sampling differ between runs.
    items_to_remove = sorted(ancestors.intersection(inventory.keys()))
    if not items_to_remove:
        raise ValueError(f"No ancestor items of {target!r} left in the inventory")

    num_items = random.randint(1, len(items_to_remove))
    for item in random.sample(items_to_remove, num_items):
        count_to_remove = random.randint(1, inventory[item])
        inventory[item] -= count_to_remove
        if inventory[item] == 0:
            del inventory[item]
    return inventory
130
-
131
-
132
def construct_example(
    target: str,
    num_distractors: int = 16,
    impossible=False,
) -> dict:
    """
    For a given target object, number of distractors, and impossible flag
    Return a dictionary with the start inventory for the crafting task

    The crafting task should be to craft the target, the inventory should contain
    the resources required for the recipe to be crafted.

    The number of distractors are how many random items should be added to the inventory.

    If impossible is True, the target item should not be craftable with the given inventory.

    The returned example always has the keys "inventory", "slotted_inventory",
    "target", "num_distractors" and "impossible". Possible examples additionally
    carry "optimal_path_length", "optimal_path", "inventory_trace",
    "items_used" and "unique_items_used".
    """

    # sample the recipe
    inventory, overall_exclude_set = sample_recipes(target, set())
    if impossible:
        # if impossible then remove one or more items from the inventory
        inventory = remove_ancestor_items(
            target,
            inventory,
        )

    # add distractors to the inventory
    distractors = sample_distractors(overall_exclude_set, num_distractors)
    inventory.update(distractors)

    optimal_path = optimal_planner(target, inventory)
    # @TODO this is a hack: keep removing ancestor items until the planner
    # finds no path, so that "impossible" examples really are impossible
    while optimal_path is not None and impossible:
        inventory = remove_ancestor_items(target, inventory)
        optimal_path = optimal_planner(target, inventory)

    # assign to slots
    inventory_list = assign_to_slots(inventory)
    example = {
        "inventory": inventory,
        "slotted_inventory": inventory_list,
        "target": target,
        "num_distractors": num_distractors,
        "impossible": impossible,
    }
    # either impossible and no path or not impossible and path exists
    assert (impossible and optimal_path is None) or (
        not impossible and optimal_path is not None
    )

    if not impossible:
        example["optimal_path_length"] = len(optimal_path)
        example["optimal_path"] = [r.result.item for (r, i) in optimal_path]
        example["inventory_trace"] = [i for (r, i) in optimal_path]
        # prepend the start inventory so the first crafting step is counted too
        items_used, unique_items_used = calculate_stats_from_inventory_trace(
            [example["inventory"]] + example["inventory_trace"]
        )
        example["items_used"] = items_used
        example["unique_items_used"] = unique_items_used

    return example
193
-
194
-
195
def calculate_stats_from_inventory_trace(
    inventory_trace: list[dict],
) -> tuple[int, int]:
    """
    Summarise how many items are consumed along a sequence of inventories.

    For every pair of consecutive inventories, the items whose count decreased
    are collected. Returns (total units consumed, number of distinct items
    consumed), both summed over all steps.
    """
    # Counter subtraction keeps only the positive differences, i.e. the items
    # that were used up between one inventory and the next.
    consumed_per_step = [
        Counter(before) - Counter(after)
        for before, after in zip(inventory_trace, inventory_trace[1:])
    ]
    units_used = sum(sum(step.values()) for step in consumed_per_step)
    distinct_used = sum(len(step) for step in consumed_per_step)
    return units_used, distinct_used
207
-
208
-
209
def generate_dataset(seed=2024, distractors=(4, 8, 16), num_examples=10):
    """
    Build a dataset of crafting examples for every target that has a recipe.

    Args:
        seed: seed applied to both `random` and `numpy.random` before sampling.
        distractors: distractor counts; examples are generated for each value.
        num_examples: number of examples per (target, distractor count) pair.

    Returns:
        List of example dicts as produced by construct_example.
    """
    random.seed(seed)
    np.random.seed(seed)

    dataset = []
    # Iterate over a snapshot of the keys.
    # NOTE(review): presumably RECIPES can gain keys while examples are being
    # sampled - confirm before iterating the mapping directly.
    for recipe_target in list(RECIPES.keys()):
        if len(RECIPES[recipe_target]) == 0:
            continue
        for num_distractors in distractors:
            for _ in range(num_examples):
                example = construct_example(
                    target=recipe_target, num_distractors=num_distractors
                )
                dataset.append(example)

    return dataset
models/__init__.py DELETED
@@ -1,21 +0,0 @@
1
- from plancraft.models.base import ABCModel
2
-
3
- from plancraft.config import EvalConfig
4
- from plancraft.models.dummy import DummyModel
5
- from plancraft.models.react import ReactModel
6
- from plancraft.models.oracle import OracleModel
7
- from plancraft.models.act import ActModel
8
-
9
-
10
def get_model(cfg: EvalConfig) -> ABCModel:
    """
    Factory get model (default: ReactModel)

    The model class is selected from `cfg.plancraft.mode`; any mode other than
    "dummy", "oracle" or "act" falls back to ReactModel.
    """
    model_classes = {
        "dummy": DummyModel,
        "oracle": OracleModel,
        "act": ActModel,
    }
    model_class = model_classes.get(cfg.plancraft.mode, ReactModel)
    return model_class(cfg)
models/act.py DELETED
@@ -1,184 +0,0 @@
1
- import copy
2
- import torch
3
- from dotenv import load_dotenv
4
-
5
- from plancraft.config import EvalConfig
6
- from plancraft.environments.actions import (
7
- NoOp,
8
- StopAction,
9
- SymbolicAction,
10
- )
11
- from plancraft.models.base import ABCModel, History
12
- from plancraft.models.bbox_model import IntegratedBoundingBoxModel
13
- from plancraft.models.few_shot_images import load_prompt_images
14
- from plancraft.models.generators import (
15
- OAMGenerator,
16
- OpenAIGenerator,
17
- TransformersGenerator,
18
- )
19
- from plancraft.models.prompts import get_prompt_example, get_system_prompt
20
- from plancraft.models.utils import (
21
- convert_observation_to_message,
22
- parse_content_response,
23
- )
24
-
25
-
26
- load_dotenv()
27
-
28
-
29
class ActModel(ABCModel):
    """
    Model that does action without thinking step
    """

    def __init__(self, cfg: EvalConfig):
        """
        Configure the model from `cfg.plancraft`: optional bounding-box model,
        underlying language model, prompts and dialogue history.
        """
        plancraft_cfg = cfg.plancraft
        assert (
            plancraft_cfg.environment.symbolic_action_space
        ), "Real action space unsupported"
        self.cfg = cfg
        self.env_is_multimodal = not plancraft_cfg.environment.symbolic
        self.use_maskrcnn = plancraft_cfg.use_maskrcnn
        self.use_multimodal_content_format = (
            plancraft_cfg.use_multimodal_content_format
        )
        self.use_text_inventory = plancraft_cfg.use_text_inventory
        self.use_images = plancraft_cfg.use_images

        # MaskRCNN is not multimodal model but a separate model
        self.bbox_model = None
        if self.use_maskrcnn:
            assert self.env_is_multimodal, "MaskRCNN only supported in multimodal mode"
            self.bbox_model = IntegratedBoundingBoxModel.from_pretrained(
                "gautierdag/plancraft-maskrcnn"
            )
            self.bbox_model.eval()
            if torch.cuda.is_available():
                self.bbox_model.cuda()

        self.few_shot = plancraft_cfg.few_shot
        self.use_system_prompt = plancraft_cfg.system_prompt
        self.max_invalid_actions = 3

        # underlying language model, chosen from the configured model name
        model_name = plancraft_cfg.model
        if "gpt-4o" in model_name:
            # this branch always uses the multimodal content format
            self.use_multimodal_content_format = True
            self.llm = OpenAIGenerator(
                use_images=self.use_images, model_name=model_name
            )
        elif "oam" in model_name:
            self.llm = OAMGenerator(model_name=model_name)
        else:
            # model is transformers based
            self.llm = TransformersGenerator(
                model_name=model_name,
                tokenizer_name=plancraft_cfg.tokenizer,
                quantize=plancraft_cfg.quantize,
                use_hot_cache=plancraft_cfg.hot_cache,
                adapter_name=plancraft_cfg.adapter,
            )

        self.prompt_images = []

        self.valid_actions = plancraft_cfg.valid_actions
        self.system_prompt_text = get_system_prompt(self.valid_actions)

        examples = get_prompt_example(
            self.valid_actions,
            use_text_inventory=self.use_text_inventory,
            use_multimodal_content_format=self.use_multimodal_content_format,
            use_images=self.use_images,
        )
        if self.env_is_multimodal and self.use_images:
            self.prompt_images = load_prompt_images()

        # the system prompt is either a list of typed parts or a plain string
        prompt_text = copy.deepcopy(self.system_prompt_text)
        if self.use_multimodal_content_format:
            system_content = [{"text": prompt_text, "type": "text"}]
        else:
            system_content = prompt_text
        self.system_prompt = (
            {"role": "system", "content": system_content}
            if self.use_system_prompt
            else None
        )

        # few-shot examples seed the dialogue only when enabled
        self.history = History(
            initial_dialogue=examples if self.few_shot else [],
            use_multimodal_content_format=self.use_multimodal_content_format,
        )

        self.max_messages_window = plancraft_cfg.max_message_window
        self.kv_cache = None

    def reset_history(
        self,
        objective: str,
    ):
        """
        Start a new episode: reset the history with `objective` (re-seeding the
        few-shot examples when enabled) and reset the language model.
        """
        examples = (
            get_prompt_example(
                self.valid_actions,
                use_text_inventory=self.use_text_inventory,
                use_multimodal_content_format=self.use_multimodal_content_format,
                use_images=self.use_images,
            )
            if self.few_shot
            else []
        )

        self.history.reset(objective=objective, initial_dialogue=examples)
        self.llm.reset()

    def step(self, observation: dict) -> SymbolicAction | StopAction:
        """
        Record `observation`, query the language model and return its action.

        Up to `max_invalid_actions` generations are attempted. Only the first
        line of each generation is kept; when it does not parse into an action,
        the string returned by the parser is appended to the dialogue as a user
        message before retrying. NoOp is returned if every attempt fails.
        """
        history = self.history
        history.add_observation_to_history(observation)

        # turn the observation into a user message
        user_message = convert_observation_to_message(
            observation,
            objective=history.objective,
            bbox_model=self.bbox_model,
            oam_model="oam" in self.llm.model_name,
            use_text_inventory=self.use_text_inventory,
            use_multimodal_content_format=self.use_multimodal_content_format,
            use_images=self.use_images,
        )
        history.add_message_to_history(content=user_message, role="user")

        # Iterate until valid action
        for _ in range(self.max_invalid_actions):
            message_window, image_window = self.llm.prepare_messages(
                history=history,
                max_messages_window=self.max_messages_window,
                system_prompt=self.system_prompt,
                prompt_images=self.prompt_images,
            )
            generations, tokens_used = self.llm.generate_unconstrained(
                batch_messages=[message_window],
                images=[image_window],
            )
            history.tokens_used += tokens_used

            first_line = generations[0].split("\n")[0].strip()
            history.add_message_to_history(content=first_line, role="assistant")

            parsed = parse_content_response(
                first_line, valid_actions=self.valid_actions
            )
            if isinstance(parsed, str):
                # not a valid action: feed the returned text back and retry
                history.add_message_to_history(
                    content=parsed,
                )
                continue

            # valid action
            history.add_action_to_history(parsed)
            return parsed

        # if no action is found after max_invalid_actions, default to useless move action
        return NoOp()
models/base.py DELETED
@@ -1,152 +0,0 @@
1
- import abc
2
-
3
- from copy import copy
4
- from collections import Counter
5
-
6
- from plancraft.environments.actions import (
7
- SymbolicMoveAction,
8
- RealActionInteraction,
9
- SymbolicSmeltAction,
10
- )
11
-
12
-
13
- class History:
14
- def __init__(
15
- self,
16
- objective: str = "",
17
- initial_dialogue: list[dict] = [],
18
- use_multimodal_content_format=False,
19
- ):
20
- self.dialogue_history = initial_dialogue
21
- self.initial_dialogue_length = len(initial_dialogue)
22
- self.action_history = []
23
- self.inventory_history = []
24
- self.inventory_counters = []
25
- self.images = []
26
- self.objective = objective
27
- self.tokens_used = 0
28
- self.use_multimodal_content_format = use_multimodal_content_format
29
-
30
- def add_message_to_history(self, content: str | dict, role="user"):
31
- if role == "assistant":
32
- print(content)
33
-
34
- if isinstance(content, dict):
35
- assert "content" in content, "content key not found in message"
36
- content["role"] = role
37
- self.dialogue_history.append(content)
38
- else:
39
- # fix for listed content type
40
- if self.use_multimodal_content_format:
41
- return self.add_message_to_history(
42
- content={
43
- "content": [{"type": "text", "text": content}],
44
- "role": role,
45
- },
46
- role=role,
47
- )
48
- else:
49
- self.dialogue_history.append({"role": role, "content": content})
50
-
51
- def add_action_to_history(
52
- self, action: SymbolicSmeltAction | RealActionInteraction | SymbolicMoveAction
53
- ):
54
- if action is None:
55
- return
56
- self.action_history.append(action.model_dump())
57
-
58
- def add_inventory_to_history(self, inventory: list[dict[str, int]]):
59
- self.inventory_history.append(inventory)
60
-
61
- # count inventory
62
- counter = Counter()
63
- for item in inventory:
64
- # ignore slot 0
65
- if "slot" in item and item["slot"] == 0:
66
- continue
67
- if "index" in item and item["index"] == 0:
68
- continue
69
- counter[item["type"]] += item["quantity"]
70
-
71
- self.inventory_counters.append(counter)
72
-
73
- def add_image_to_history(self, image):
74
- self.images.append(image)
75
-
76
- def add_observation_to_history(self, observation: dict):
77
- if observation is None:
78
- return
79
- if "inventory" in observation:
80
- clean_inv = []
81
- # remove empty slots
82
- for item in observation["inventory"]:
83
- if item["quantity"] > 0:
84
- clean_inv.append(item)
85
- self.add_inventory_to_history(clean_inv)
86
- if "pov" in observation:
87
- self.add_image_to_history(observation["pov"])
88
-
89
- def __str__(self):
90
- return str(self.dialogue_history)
91
-
92
- def reset(self, objective: str = "", initial_dialogue: list[dict] = []):
93
- self.dialogue_history = initial_dialogue
94
- self.action_history = []
95
- self.inventory_history = []
96
- self.inventory_counters = []
97
- self.images = []
98
- self.objective = objective
99
-
100
- def set_objective(self, objective: str):
101
- self.objective = objective
102
-
103
- def trace(self):
104
- return {
105
- "dialogue_history": copy(
106
- self.dialogue_history[self.initial_dialogue_length :]
107
- ),
108
- "action_history": copy(self.action_history),
109
- "inventory_history": copy(self.inventory_history),
110
- "objective": copy(self.objective),
111
- "tokens_used": copy(self.tokens_used),
112
- }
113
-
114
- @property
115
- def num_steps(self):
116
- return len(self.action_history)
117
-
118
- def check_stuck(self, max_steps_no_change: int = 10) -> bool:
119
- """
120
- If inventory content does not change for max_steps_no_change steps
121
- the agent is considered stuck.
122
-
123
- With N=10, the oracle solver can still solve 100% of the examples
124
- """
125
- if len(self.inventory_counters) <= max_steps_no_change:
126
- return False
127
-
128
- return all(
129
- c == self.inventory_counters[-max_steps_no_change - 1]
130
- for c in self.inventory_counters[-max_steps_no_change - 1 :]
131
- )
132
-
133
-
134
class ABCModel(abc.ABC):
    """
    Interface a model has to provide so that the evaluator can drive it
    """

    @abc.abstractmethod
    def step(
        self, observation: list[dict]
    ) -> list[SymbolicMoveAction | RealActionInteraction | SymbolicSmeltAction]:
        """
        Produce a valid action (one of the 3 available types) from an observation

        As annotated this is a batch operation: a list of observations comes in
        and one action per observation is returned.
        NOTE(review): ActModel.step takes a single observation dict and returns
        a single action - confirm which contract the evaluator relies on.
        """
        raise NotImplementedError

    def reset_history(self, objective: str = ""):
        """Reset `self.history` for a new episode with the given objective."""
        self.history.reset(objective=objective)