plancraft 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,224 @@
+ import math
+ import random
+ from collections import Counter
+
+ import numpy as np
+ from plancraft.environments.items import all_data, ALL_ITEMS
+ from plancraft.environments.recipes import RECIPES
+ from plancraft.environments.planner import optimal_planner, get_ancestors
+
+
+ MAX_STACK_SIZE = {}
+ for data_item in all_data["items"]:
+     if data_item["stackable"]:
+         MAX_STACK_SIZE[data_item["type"]] = data_item["stackSize"]
+     else:
+         MAX_STACK_SIZE[data_item["type"]] = 1
+
+
+ def sample_distractors(
+     exclude_set: set = None, num_distractors: int = 16
+ ) -> dict[str, int]:
+     distractors = {}
+     while len(distractors) < num_distractors:
+         item = random.choice(ALL_ITEMS)
+         if exclude_set is not None and item in exclude_set:
+             continue
+         count = random.randint(1, MAX_STACK_SIZE[item])
+         distractors[item] = count
+     return distractors
+
+
+ def assign_to_slots(inventory: dict[str, int]) -> list[dict]:
+     # slots available outside of crafting interface
+     available_slots = list(range(10, 46))
+     random.shuffle(available_slots)
+     inventory_list = []
+
+     for item, total_count in inventory.items():
+         while total_count > 0:
+             if len(available_slots) == 0:
+                 print("Not enough slots available")
+                 break
+             slot = available_slots.pop()
+             count_in_slot = min(total_count, MAX_STACK_SIZE[item])
+             inventory_list.append({"slot": slot, "item": item, "count": count_in_slot})
+             total_count -= count_in_slot
+
+     return inventory_list
+
+
+ def sample_recipes(
+     target: str,
+     overall_exclude_set: set,
+     target_count: int = 1,
+     current_depth=0,
+     max_depth=20,
+ ) -> tuple[dict[str, int], set]:
+     # stop if the depth is too high
+     if current_depth > max_depth:
+         return {}, overall_exclude_set
+
+     # get all the recipes that can craft the target
+     overall_exclude_set.update([target])
+     local_exclude_set = set()
+     random_recipes = []
+     for r in RECIPES[target]:
+         recipe_inputs, exclude_set = r.sample_inputs()
+         # if inputs are already in the exclude set, skip this recipe (ensures no cycle)
+         if exclude_set.intersection(overall_exclude_set):
+             return {}, overall_exclude_set
+         local_exclude_set.update(exclude_set)
+         random_recipes.append((r, recipe_inputs))
+
+     overall_exclude_set |= local_exclude_set
+
+     # no recipes found
+     if len(random_recipes) == 0:
+         return {}, overall_exclude_set
+
+     # sample a random recipe
+     random_recipe = random.choice(random_recipes)
+     recipe, start_inputs = random_recipe
+
+     # recipe will not produce enough
+     if recipe.result.count < target_count:
+         # must do recipe X times
+         recipe_multiplier = math.ceil(target_count / recipe.result.count)
+         start_inputs = {k: v * recipe_multiplier for k, v in start_inputs.items()}
+
+     for input_item in list(start_inputs.keys()):
+         # randomize depth first search to end early
+         if random.choice([True, False]):
+             continue
+
+         children_recipe_inputs, updated_exclude_set = sample_recipes(
+             target=input_item,
+             overall_exclude_set=overall_exclude_set,
+             target_count=start_inputs[input_item],
+             current_depth=current_depth + 1,
+         )
+         if len(children_recipe_inputs) == 0:
+             continue
+
+         overall_exclude_set.update(updated_exclude_set)
+
+         # remove recipe input item since we are crafting it
+         start_inputs[input_item] = 0
+
+         # add the children recipe inputs
+         for item, count in children_recipe_inputs.items():
+             start_inputs[item] = start_inputs.get(item, 0) + count
+
+     overall_exclude_set = overall_exclude_set - {None}
+     start_inputs = {k: v for k, v in start_inputs.items() if v > 0}
+
+     return start_inputs, overall_exclude_set
+
+
+ def remove_ancestor_items(target: str, inventory: dict[str, int]) -> dict[str, int]:
+     ancestors = set(get_ancestors(target))
+     possible_items = set(inventory.keys())
+     items_to_remove = list(ancestors.intersection(possible_items))
+     num_items = random.randint(1, len(items_to_remove))
+     for item in random.sample(items_to_remove, num_items):
+         count_to_remove = random.randint(1, inventory[item])
+         inventory[item] -= count_to_remove
+         if inventory[item] == 0:
+             del inventory[item]
+     return inventory
+
+
+ def construct_example(
+     target: str,
+     num_distractors: int = 16,
+     impossible=False,
+ ) -> dict:
+ """
138
+ For a given target object, number of distractors, and impossible flag
139
+ Return a dictionary with the start inventory for the crafting task
140
+
141
+ The crafting task should be to craft the target, the inventory should contain
142
+ the resources required for the recipe to be crafted.
143
+
144
+ The number of distractors are how many random items should be added to the inventory.
145
+
146
+ If impossible is True, the target item should not be craftable with the given inventory.
147
+ """
148
+
149
+ # sample the recipe
150
+ inventory, overall_exclude_set = sample_recipes(target, set())
151
+ if impossible:
152
+ # if impossible then remove one or more items from the inventory
153
+ inventory = remove_ancestor_items(
154
+ target,
155
+ inventory,
156
+ )
157
+
158
+ # add distractors to the inventory
159
+ distractors = sample_distractors(overall_exclude_set, num_distractors)
160
+ inventory.update(distractors)
161
+
162
+ optimal_path = optimal_planner(target, inventory)
163
+ # @TODO this is a hack to ensure that we don't have impossible examples
164
+ while optimal_path is not None and impossible:
165
+ inventory = remove_ancestor_items(target, inventory)
166
+ optimal_path = optimal_planner(target, inventory)
167
+
168
+ # assign to slots
169
+ inventory_list = assign_to_slots(inventory)
170
+ example = {
171
+ "inventory": inventory,
172
+ "slotted_inventory": inventory_list,
173
+ "target": target,
174
+ "num_distractors": num_distractors,
175
+ "impossible": impossible,
176
+ }
177
+ # either impossible and no path or not impossible and path exists
178
+ assert (impossible and optimal_path is None) or (
179
+ not impossible and optimal_path is not None
180
+ )
181
+
182
+ if not impossible:
183
+ example["optimal_path_length"] = len(optimal_path)
184
+ example["optimal_path"] = [r.result.item for (r, i) in optimal_path]
185
+ example["inventory_trace"] = [i for (r, i) in optimal_path]
186
+ items_used, unique_items_used = calculate_stats_from_inventory_trace(
187
+ [example["inventory"]] + example["inventory_trace"]
188
+ )
189
+ example["items_used"] = items_used
190
+ example["unique_items_used"] = unique_items_used
191
+
192
+ return example
193
+
194
+
195
+ def calculate_stats_from_inventory_trace(
196
+ inventory_trace: list[dict],
197
+ ) -> tuple[int, int]:
198
+ total_items_used = 0
199
+ total_unique_items_used = 0
200
+
201
+ for a, b in zip(inventory_trace[:-1], inventory_trace[1:]):
202
+ diff = Counter(a) - Counter(b)
203
+ total_items_used += sum(diff.values())
204
+ total_unique_items_used += len(diff)
205
+
206
+ return total_items_used, total_unique_items_used
207
+
208
+
209
+ def generate_dataset(seed=2024, distractors=[4, 8, 16], num_examples=10):
210
+ random.seed(seed)
211
+ np.random.seed(seed)
212
+
213
+ dataset = []
214
+ for recipe_target in list(RECIPES.keys()):
215
+ if len(RECIPES[recipe_target]) == 0:
216
+ continue
217
+ for num_distractors in distractors:
218
+ for _ in range(num_examples):
219
+ example = construct_example(
220
+ target=recipe_target, num_distractors=num_distractors
221
+ )
222
+ dataset.append(example)
223
+
224
+ return dataset
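
A minimal sketch of how the helpers above might be driven. The import path plancraft.generate_dataset is an assumption (the hunk header above omits the filename), and "ladder" and data/train.json are illustrative values only; construct_example and generate_dataset are the functions defined in this file.

    # Sketch only: the import path below is assumed, not confirmed by this diff.
    import json

    from plancraft.generate_dataset import construct_example, generate_dataset  # hypothetical path

    # One solvable example for a single target with 8 distractor items.
    example = construct_example(target="ladder", num_distractors=8, impossible=False)
    print(example["target"], example["optimal_path_length"])

    # Full dataset: every craftable target x each distractor count x num_examples samples.
    dataset = generate_dataset(seed=2024, distractors=[4, 8, 16], num_examples=10)
    with open("data/train.json", "w") as f:  # path mirrors the evaluator's data/<split>.json convention
        json.dump(dataset, f)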
plancraft/evaluator.py ADDED
@@ -0,0 +1,273 @@
+ import json
+ import os
+ import random
+ import string
+ import time
+
+ import imageio
+ import pandas as pd
+ import torch
+ import wandb
+ from loguru import logger
+ from tqdm import tqdm
+
+ from plancraft.config import EvalConfig, PlancraftExample
+ from plancraft.environments.actions import StopAction
+ from plancraft.environments.env_real import RealPlancraft
+ from plancraft.environments.env_symbolic import SymbolicPlancraft
+ from plancraft.models import get_model
+
+ wandb.require("core")
+
+
+ class Evaluator:
+     """
+     The evaluator class handles the environment loop and model interaction
+
+     The environment is created based on the configuration and the examples are loaded from the dataset.
+     """
+
+     def __init__(self, cfg: EvalConfig):
+         self.cfg = cfg
+         self.output_dir = (
+             f"{cfg.plancraft.output_dir}/{self.evaluator_name()}/{cfg.plancraft.split}"
+         )
+         self.generation_number = 0
+
+         self.examples: list[PlancraftExample] = self.load_dataset(cfg.plancraft.split)
+
+         self.environment = self.create_env(cfg)
+         self.model = get_model(cfg)
+
+         self.record_frames = not (cfg.plancraft.environment.symbolic)
+
+         # no_op action
+         self.no_op = self.environment.action_space.no_op()
+
+     def evaluator_name(self) -> str:
+         symb_str = "real"
+         if self.cfg.plancraft.environment.symbolic:
+             symb_str = "symb"
+
+         if self.cfg.plancraft.use_maskrcnn:
+             symb_str += "_mrcnn"
+
+         model_name = self.cfg.plancraft.model.split("/")[-1]
+         if self.cfg.plancraft.adapter != "":
+             model_name = self.cfg.plancraft.adapter.split("/")[-1]
+
+         mode = self.cfg.plancraft.mode
+         if mode in ["dummy", "oracle"]:
+             return f"{mode}_{symb_str}"
+
+         actions = "|".join(self.cfg.plancraft.valid_actions)
+         return f"{self.cfg.plancraft.mode}_{symb_str}_{model_name}_{actions}"
+
+     def save_results_dict(self, example: PlancraftExample, results_dict: dict):
+         output_dir = f"{self.output_dir}/{self.generation_number}"
+         os.makedirs(output_dir, exist_ok=True)
+         json_path = f"{output_dir}/{example.id}.json"
+         with open(json_path, "w") as f:
+             json.dump(results_dict, f, indent=4)
+         wandb.save(json_path, policy="now")
+
+     def save_images(self, example: PlancraftExample, frames: list):
+         if len(frames) == 0:
+             return
+         output_dir = f"{self.output_dir}/{self.generation_number}"
+         os.makedirs(output_dir, exist_ok=True)
+         imageio.mimsave(f"{output_dir}/{example.id}.gif", frames)
+         # upload to wandb
+         wandb.save(f"{output_dir}/{example.id}.gif", policy="now")
+
+     def load_results_dict(self, example: PlancraftExample) -> dict:
+         path = f"{self.output_dir}/{self.generation_number}/{example.id}.json"
+         if not os.path.exists(path) or not self.cfg.plancraft.resume:
+             return None
+         with open(path, "r") as f:
+             return json.load(f)
+
+     def create_env(self, cfg: EvalConfig) -> RealPlancraft | SymbolicPlancraft:
+         if cfg.plancraft.environment.symbolic:
+             return SymbolicPlancraft(inventory=[])
+         return RealPlancraft(
+             inventory=[],
+             symbolic_action_space=cfg.plancraft.environment.symbolic_action_space,
+             symbolic_observation_space=cfg.plancraft.environment.symbolic_observation_space,
+             preferred_spawn_biome=cfg.plancraft.environment.preferred_spawn_biome,
+             resolution=cfg.plancraft.environment.resolution,
+         )
+
+     def close(self):
+         self.environment.close()
+
+     def load_dataset(self, dataset_split: str) -> list[PlancraftExample]:
+         with open(f"data/{dataset_split}.json", "r") as f:
+             dataset = json.load(f)
+             return [PlancraftExample(**example) for example in dataset]
+
+     def reset(
+         self,
+         example: PlancraftExample,
+     ):
+         current_inventory = example.slotted_inventory
+         self.environment.fast_reset(new_inventory=current_inventory)
+         # do a no op to an initial observation
+         obs, _, _, _ = self.environment.step(self.no_op)
+         # assert that the inventory is correct
+         if "inventory" in obs:
+             for item in current_inventory:
+                 slot = item["slot"]
+                 if (
+                     obs["inventory"][slot]["type"] != item["type"]
+                     or obs["inventory"][slot]["quantity"] != item["quantity"]
+                 ) and item["type"] != "air":
+                     logger.warning(f"Inventory does not match expected for slot {slot}")
+                     logger.warning(f"Expected {item}")
+                     logger.warning(f"Got {obs['inventory'][slot]}")
+                     # try again
+                     self.reset(example)
+
+         objective = f"Craft an item of type: {example.target}"
+         self.model.reset_history(objective=objective)
+
+     def check_done(self, inventory: list[dict[str, int]], target: str):
+         """
+         Check that target object is obtained
+         """
+         for item in inventory:
+             if target == item["type"]:
+                 # ensure item is taken out of crafting slot
+                 if "slot" in item and item["slot"] != 0:
+                     return True
+                 if "index" in item and item["index"] != 0:
+                     return True
+         return False
+
+     @torch.no_grad()
+     def eval_all_examples(self, progress_bar=False) -> list:
+         results = []
+         action = self.no_op.copy()
+
+         pbar = tqdm(
+             total=len(self.examples),
+             disable=not progress_bar,
+         )
+         correct = 0
+         count = 0
+
+         for example in self.examples:
+             if resume_result := self.load_results_dict(example):
+                 pbar.update(self.cfg.plancraft.max_steps)
+                 results.append(resume_result)
+                 continue
+
+             success = False
+
+             self.reset(example)
+             action = self.no_op.copy()
+
+             while (
+                 not self.model.history.check_stuck()
+                 and self.model.history.num_steps < self.cfg.plancraft.max_steps
+             ):
+                 # if the action is stop then we end the episode
+                 if isinstance(action, StopAction):
+                     # if the action is stop and task is impossible then success
+                     # otherwise we should not have stopped
+                     success = example.impossible
+                     break
+
+                 # step action
+                 observation, _, _, _ = self.environment.step(action)
+
+                 # check if the episode is done
+                 success = self.check_done(observation["inventory"], example.target)
+                 # exit if success
+                 if success:
+                     break
+
+                 # predict next action
+                 action = self.model.step(observation)
+
+             # save results and reset
+             result = {
+                 "success": success,
+                 "recipe_type": example.recipe_type,
+                 "number_of_steps": self.model.history.num_steps,
+                 "model_trace": self.model.history.trace(),
+                 "example_id": example.id,
+                 "impossible": example.impossible,
+             }
+             results.append(result)
+             self.save_results_dict(example, result)
+             self.save_images(example, self.model.history.images)
+
+             correct += int(result["success"])
+             count += 1
+
+             acc = correct / count
+             pbar.set_postfix(correct=correct, count=count, acc=acc)
+             pbar.update(1)
+
+         return results
+
+     def eval_all(self):
+         logger.info(
+             f"Running evaluation over {len(self.examples)} examples {self.cfg.plancraft.num_generations} times."
+         )
+         run_name = (
+             f"{self.evaluator_name()} {self.cfg.plancraft.split}".replace(" ", "_")
+             .replace(".", "_")
+             .strip()
+         )
+
+         for n in range(self.cfg.plancraft.num_generations):
+             logger.info(f"Generation {n+1}/{self.cfg.plancraft.num_generations}")
+             run_id = "".join(random.choices(string.ascii_lowercase, k=5))
+             generation_run_name = run_name + f"_{run_id}"
+
+             wandb.init(
+                 name=generation_run_name,
+                 project=self.cfg.wandb.project,
+                 entity=self.cfg.wandb.entity,
+                 mode=self.cfg.wandb.mode,
+                 group=self.cfg.plancraft.model,
+                 job_type=self.cfg.plancraft.mode,
+                 config=self.cfg.model_dump(),
+             )
+             time_now = time.time()
+
+             results_list = self.eval_all_examples(progress_bar=True)
+
+             results_df = pd.DataFrame(results_list)
+
+             output = {
+                 "avg_success_rate": results_df["success"].mean(),
+                 "avg_number_of_steps": results_df["number_of_steps"].mean(),
+                 "avg_num_tokens_used": results_df["model_trace"]
+                 .apply(pd.Series)["tokens_used"]
+                 .mean(),
+             }
+
+             # calculate success rate for each recipe type
+             recipe_types = results_df["recipe_type"].unique()
+             for recipe_type in recipe_types:
+                 mask = results_df["recipe_type"] == recipe_type
+                 success_rate = results_df[mask]["success"].mean()
+                 output[f"{recipe_type}_success_rate"] = success_rate
+
+             time_elapsed = time.time() - time_now
+             logger.info(f"Time elapsed: {time_elapsed:.2f}s")
+
+             logger.info(output)
+             wandb.log(output)
+             table = wandb.Table(
+                 dataframe=results_df[["success", "number_of_steps", "example_id"]]
+             )
+             wandb.log({"results": table})
+             wandb.finish()
+
+             self.generation_number += 1
+
+         logger.info("Done")
@@ -0,0 +1,21 @@
+ from plancraft.models.base import ABCModel
+
+ from plancraft.config import EvalConfig
+ from plancraft.models.dummy import DummyModel
+ from plancraft.models.react import ReactModel
+ from plancraft.models.oracle import OracleModel
+ from plancraft.models.act import ActModel
+
+
+ def get_model(cfg: EvalConfig) -> ABCModel:
+     """
+     Factory that returns the model for the configured mode (default: ReactModel)
+     """
+     if cfg.plancraft.mode == "dummy":
+         return DummyModel(cfg)
+     elif cfg.plancraft.mode == "oracle":
+         return OracleModel(cfg)
+     elif cfg.plancraft.mode == "act":
+         return ActModel(cfg)
+     else:
+         return ReactModel(cfg)
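
The factory dispatches purely on cfg.plancraft.mode and returns an ABCModel that the Evaluator drives through reset_history and step. A hedged sketch of that contract (cfg is again assumed to be a populated EvalConfig):

    # Sketch only: mirrors how Evaluator uses the factory's return value above.
    from plancraft.models import get_model

    def make_and_prime_model(cfg, target: str):
        model = get_model(cfg)  # "dummy" -> DummyModel, "oracle" -> OracleModel,
                                # "act" -> ActModel, anything else -> ReactModel
        model.reset_history(objective=f"Craft an item of type: {target}")
        return model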
@@ -0,0 +1,184 @@
+ import copy
+ import torch
+ from dotenv import load_dotenv
+
+ from plancraft.config import EvalConfig
+ from plancraft.environments.actions import (
+     NoOp,
+     StopAction,
+     SymbolicAction,
+ )
+ from plancraft.models.base import ABCModel, History
+ from plancraft.models.bbox_model import IntegratedBoundingBoxModel
+ from plancraft.models.few_shot_images import load_prompt_images
+ from plancraft.models.generators import (
+     OAMGenerator,
+     OpenAIGenerator,
+     TransformersGenerator,
+ )
+ from plancraft.models.prompts import get_prompt_example, get_system_prompt
+ from plancraft.models.utils import (
+     convert_observation_to_message,
+     parse_content_response,
+ )
+
+
+ load_dotenv()
+
+
+ class ActModel(ABCModel):
+     """
+     Model that does action without thinking step
+     """
+
+     def __init__(self, cfg: EvalConfig):
+         assert (
+             cfg.plancraft.environment.symbolic_action_space
+         ), "Real action space unsupported"
+         self.cfg = cfg
+         self.env_is_multimodal = not cfg.plancraft.environment.symbolic
+         self.use_maskrcnn = cfg.plancraft.use_maskrcnn
+         self.use_multimodal_content_format = cfg.plancraft.use_multimodal_content_format
+         self.use_text_inventory = cfg.plancraft.use_text_inventory
+         self.use_images = cfg.plancraft.use_images
+
+         self.bbox_model = None
+         if self.use_maskrcnn:
+             assert self.env_is_multimodal, "MaskRCNN only supported in multimodal mode"
+             self.bbox_model = IntegratedBoundingBoxModel.from_pretrained(
+                 "gautierdag/plancraft-maskrcnn"
+             )
+             self.bbox_model.eval()
+             if torch.cuda.is_available():
+                 self.bbox_model.cuda()
+             # MaskRCNN is not multimodal model but a separate model
+
+         self.few_shot = cfg.plancraft.few_shot
+         self.use_system_prompt = cfg.plancraft.system_prompt
+         self.max_invalid_actions = 3
+
+         # underlying language model
+         if "gpt-4o" in cfg.plancraft.model:
+             self.use_multimodal_content_format = True
+             self.llm = OpenAIGenerator(
+                 use_images=self.use_images, model_name=cfg.plancraft.model
+             )
+         elif "oam" in cfg.plancraft.model:
+             self.llm = OAMGenerator(model_name=cfg.plancraft.model)
+         else:
+             # model is transformers based
+             self.llm = TransformersGenerator(
+                 model_name=cfg.plancraft.model,
+                 tokenizer_name=cfg.plancraft.tokenizer,
+                 quantize=cfg.plancraft.quantize,
+                 use_hot_cache=cfg.plancraft.hot_cache,
+                 adapter_name=cfg.plancraft.adapter,
+             )
+
+         self.prompt_images = []
+
+         self.valid_actions = cfg.plancraft.valid_actions
+         self.system_prompt_text = get_system_prompt(self.valid_actions)
+
+         examples = get_prompt_example(
+             self.valid_actions,
+             use_text_inventory=self.use_text_inventory,
+             use_multimodal_content_format=self.use_multimodal_content_format,
+             use_images=self.use_images,
+         )
+         if self.env_is_multimodal and self.use_images:
+             self.prompt_images = load_prompt_images()
+
+         if self.use_multimodal_content_format:
+             self.system_prompt = {
+                 "role": "system",
+                 "content": [
+                     {"text": copy.deepcopy(self.system_prompt_text), "type": "text"}
+                 ],
+             }
+         else:
+             self.system_prompt = {
+                 "role": "system",
+                 "content": copy.deepcopy(self.system_prompt_text),
+             }
+
+         if not self.few_shot:
+             examples = []
+         if not self.use_system_prompt:
+             self.system_prompt = None
+
+         self.history = History(
+             initial_dialogue=examples,
+             use_multimodal_content_format=self.use_multimodal_content_format,
+         )
+
+         self.max_messages_window = cfg.plancraft.max_message_window
+         self.kv_cache = None
+
+     def reset_history(
+         self,
+         objective: str,
+     ):
+         examples = []
+         if self.few_shot:
+             examples = get_prompt_example(
+                 self.valid_actions,
+                 use_text_inventory=self.use_text_inventory,
+                 use_multimodal_content_format=self.use_multimodal_content_format,
+                 use_images=self.use_images,
+             )
+
+         self.history.reset(objective=objective, initial_dialogue=examples)
+         self.llm.reset()
+
+     def step(self, observation: dict) -> SymbolicAction | StopAction:
+         self.history.add_observation_to_history(observation)
+
+         # add observation to history
+         observation_message = convert_observation_to_message(
+             observation,
+             objective=self.history.objective,
+             bbox_model=self.bbox_model,
+             oam_model="oam" in self.llm.model_name,
+             use_text_inventory=self.use_text_inventory,
+             use_multimodal_content_format=self.use_multimodal_content_format,
+             use_images=self.use_images,
+         )
+         self.history.add_message_to_history(content=observation_message, role="user")
+
+         # Iterate until valid action
+         i = 0
+         while i < self.max_invalid_actions:
+             # add observation to history
+             message_window, image_window = self.llm.prepare_messages(
+                 history=self.history,
+                 max_messages_window=self.max_messages_window,
+                 system_prompt=self.system_prompt,
+                 prompt_images=self.prompt_images,
+             )
+             action_messages, action_token_used = self.llm.generate_unconstrained(
+                 batch_messages=[message_window],
+                 images=[image_window],
+             )
+             self.history.tokens_used += action_token_used
+
+             action_message = action_messages[0].split("\n")[0].strip()
+
+             self.history.add_message_to_history(
+                 content=action_message, role="assistant"
+             )
+             response = parse_content_response(
+                 action_message, valid_actions=self.valid_actions
+             )
+             if not isinstance(response, str):
+                 # valid action
+                 self.history.add_action_to_history(response)
+                 return response
+
+             self.history.add_message_to_history(
+                 content=response,
+             )
+             i += 1
+
+         # if no action is found after max_invalid_actions, default to useless move action
+         return NoOp()
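
A sketch of how ActModel plugs into the environment loop above; the loop shape and the no-op bootstrap follow Evaluator.eval_all_examples, while max_steps=30 is an illustrative value standing in for cfg.plancraft.max_steps.

    # Sketch only: a stripped-down episode loop in the style of the Evaluator above.
    from plancraft.environments.actions import StopAction
    from plancraft.models.act import ActModel

    def run_episode(model: ActModel, environment, target: str, max_steps: int = 30):
        model.reset_history(objective=f"Craft an item of type: {target}")
        action = environment.action_space.no_op()   # bootstrap with a no-op, as the Evaluator does
        for _ in range(max_steps):
            observation, _, _, _ = environment.step(action)
            action = model.step(observation)         # SymbolicAction, StopAction, or NoOp fallback
            if isinstance(action, StopAction):
                break                                # the model declares the task impossible
        return model.history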