plancraft 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl

@@ -0,0 +1,224 @@
+ import math
+ import random
+ from collections import Counter
+
+ import numpy as np
+ from plancraft.environments.items import all_data, ALL_ITEMS
+ from plancraft.environments.recipes import RECIPES
+ from plancraft.environments.planner import optimal_planner, get_ancestors
+
+
+ MAX_STACK_SIZE = {}
+ for data_item in all_data["items"]:
+     if data_item["stackable"]:
+         MAX_STACK_SIZE[data_item["type"]] = data_item["stackSize"]
+     else:
+         MAX_STACK_SIZE[data_item["type"]] = 1
+
+
+ def sample_distractors(
+     exclude_set: set = None, num_distractors: int = 16
+ ) -> dict[str, int]:
+     distractors = {}
+     while len(distractors) < num_distractors:
+         item = random.choice(ALL_ITEMS)
+         if exclude_set is not None and item in exclude_set:
+             continue
+         count = random.randint(1, MAX_STACK_SIZE[item])
+         distractors[item] = count
+     return distractors
+
+
+ def assign_to_slots(inventory: dict[str, int]) -> list[dict]:
+     # slots available outside of crafting interface
+     available_slots = list(range(10, 46))
+     random.shuffle(available_slots)
+     inventory_list = []
+
+     for item, total_count in inventory.items():
+         while total_count > 0:
+             if len(available_slots) == 0:
+                 print("Not enough slots available")
+                 break
+             slot = available_slots.pop()
+             count_in_slot = min(total_count, MAX_STACK_SIZE[item])
+             inventory_list.append({"slot": slot, "item": item, "count": count_in_slot})
+             total_count -= count_in_slot
+
+     return inventory_list
+
+
+ def sample_recipes(
+     target: str,
+     overall_exclude_set: set,
+     target_count: int = 1,
+     current_depth=0,
+     max_depth=20,
+ ) -> tuple[dict[str, int], set]:
+     # stop if the depth is too high
+     if current_depth > max_depth:
+         return {}, overall_exclude_set
+
+     # get all the recipes that can craft the target
+     overall_exclude_set.update([target])
+     local_exclude_set = set()
+     random_recipes = []
+     for r in RECIPES[target]:
+         recipe_inputs, exclude_set = r.sample_inputs()
+         # if inputs are already in the exclude set, skip this recipe (ensures no cycle)
+         if exclude_set.intersection(overall_exclude_set):
+             return {}, overall_exclude_set
+         local_exclude_set.update(exclude_set)
+         random_recipes.append((r, recipe_inputs))
+
+     overall_exclude_set |= local_exclude_set
+
+     # no recipes found
+     if len(random_recipes) == 0:
+         return {}, overall_exclude_set
+
+     # sample a random recipe
+     random_recipe = random.choice(random_recipes)
+     recipe, start_inputs = random_recipe
+
+     # recipe will not produce enough
+     if recipe.result.count < target_count:
+         # must do recipe X times
+         recipe_multiplier = math.ceil(target_count / recipe.result.count)
+         start_inputs = {k: v * recipe_multiplier for k, v in start_inputs.items()}
+
+     for input_item in list(start_inputs.keys()):
+         # randomize depth first search to end early
+         if random.choice([True, False]):
+             continue
+
+         children_recipe_inputs, updated_exclude_set = sample_recipes(
+             target=input_item,
+             overall_exclude_set=overall_exclude_set,
+             target_count=start_inputs[input_item],
+             current_depth=current_depth + 1,
+         )
+         if len(children_recipe_inputs) == 0:
+             continue
+
+         overall_exclude_set.update(updated_exclude_set)
+
+         # remove recipe input item since we are crafting it
+         start_inputs[input_item] = 0
+
+         # add the children recipe inputs
+         for item, count in children_recipe_inputs.items():
+             start_inputs[item] = start_inputs.get(item, 0) + count
+
+     overall_exclude_set = overall_exclude_set - {None}
+     start_inputs = {k: v for k, v in start_inputs.items() if v > 0}
+
+     return start_inputs, overall_exclude_set
+
+
+ def remove_ancestor_items(target: str, inventory: dict[str, int]) -> dict[str, int]:
+     ancestors = set(get_ancestors(target))
+     possible_items = set(inventory.keys())
+     items_to_remove = list(ancestors.intersection(possible_items))
+     num_items = random.randint(1, len(items_to_remove))
+     for item in random.sample(items_to_remove, num_items):
+         count_to_remove = random.randint(1, inventory[item])
+         inventory[item] -= count_to_remove
+         if inventory[item] == 0:
+             del inventory[item]
+     return inventory
+
+
+ def construct_example(
+     target: str,
+     num_distractors: int = 16,
+     impossible: bool = False,
+ ) -> dict:
+     """
+     For a given target item, a number of distractors, and an impossible flag,
+     return a dictionary describing the start inventory for the crafting task.
+
+     The crafting task is to craft the target, so the inventory contains the
+     resources required by the sampled recipes.
+
+     num_distractors is the number of random items added to the inventory.
+
+     If impossible is True, the target item should not be craftable with the given inventory.
+     """
+
+     # sample the recipe
+     inventory, overall_exclude_set = sample_recipes(target, set())
+     if impossible:
+         # if impossible then remove one or more items from the inventory
+         inventory = remove_ancestor_items(
+             target,
+             inventory,
+         )
+
+     # add distractors to the inventory
+     distractors = sample_distractors(overall_exclude_set, num_distractors)
+     inventory.update(distractors)
+
+     optimal_path = optimal_planner(target, inventory)
+     # @TODO this is a hack to ensure that impossible examples are actually unsolvable
+     while optimal_path is not None and impossible:
+         inventory = remove_ancestor_items(target, inventory)
+         optimal_path = optimal_planner(target, inventory)
+
+     # assign to slots
+     inventory_list = assign_to_slots(inventory)
+     example = {
+         "inventory": inventory,
+         "slotted_inventory": inventory_list,
+         "target": target,
+         "num_distractors": num_distractors,
+         "impossible": impossible,
+     }
+     # either impossible and no path or not impossible and path exists
+     assert (impossible and optimal_path is None) or (
+         not impossible and optimal_path is not None
+     )
+
+     if not impossible:
+         example["optimal_path_length"] = len(optimal_path)
+         example["optimal_path"] = [r.result.item for (r, i) in optimal_path]
+         example["inventory_trace"] = [i for (r, i) in optimal_path]
+         items_used, unique_items_used = calculate_stats_from_inventory_trace(
+             [example["inventory"]] + example["inventory_trace"]
+         )
+         example["items_used"] = items_used
+         example["unique_items_used"] = unique_items_used
+
+     return example
+
+
+ def calculate_stats_from_inventory_trace(
+     inventory_trace: list[dict],
+ ) -> tuple[int, int]:
+     total_items_used = 0
+     total_unique_items_used = 0
+
+     for a, b in zip(inventory_trace[:-1], inventory_trace[1:]):
+         diff = Counter(a) - Counter(b)
+         total_items_used += sum(diff.values())
+         total_unique_items_used += len(diff)
+
+     return total_items_used, total_unique_items_used
+
+
+ def generate_dataset(seed=2024, distractors=[4, 8, 16], num_examples=10):
+     random.seed(seed)
+     np.random.seed(seed)
+
+     dataset = []
+     for recipe_target in list(RECIPES.keys()):
+         if len(RECIPES[recipe_target]) == 0:
+             continue
+         for num_distractors in distractors:
+             for _ in range(num_examples):
+                 example = construct_example(
+                     target=recipe_target, num_distractors=num_distractors
+                 )
+                 dataset.append(example)
+
+     return dataset
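
A minimal usage sketch for the dataset generator above. This hunk does not name the file it adds, so the import path below is an assumption, and the "train" split name is purely illustrative; everything else only uses functions defined in the hunk.

import json

# Module path assumed (the hunk does not name the added file); adjust to the real location.
from plancraft.environments.sampler import generate_dataset

if __name__ == "__main__":
    dataset = generate_dataset(seed=2024, distractors=[4, 8, 16], num_examples=10)
    # the evaluator below reads examples from data/<split>.json
    with open("data/train.json", "w") as f:
        json.dump(dataset, f)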
plancraft/evaluator.py ADDED
@@ -0,0 +1,273 @@
+ import json
+ import os
+ import random
+ import string
+ import time
+
+ import imageio
+ import pandas as pd
+ import torch
+ import wandb
+ from loguru import logger
+ from tqdm import tqdm
+
+ from plancraft.config import EvalConfig, PlancraftExample
+ from plancraft.environments.actions import StopAction
+ from plancraft.environments.env_real import RealPlancraft
+ from plancraft.environments.env_symbolic import SymbolicPlancraft
+ from plancraft.models import get_model
+
+ wandb.require("core")
+
+
+ class Evaluator:
+     """
+     The evaluator class handles the environment loop and model interaction
+
+     The environment is created based on the configuration and the examples are loaded from the dataset.
+     """
+
+     def __init__(self, cfg: EvalConfig):
+         self.cfg = cfg
+         self.output_dir = (
+             f"{cfg.plancraft.output_dir}/{self.evaluator_name()}/{cfg.plancraft.split}"
+         )
+         self.generation_number = 0
+
+         self.examples: list[PlancraftExample] = self.load_dataset(cfg.plancraft.split)
+
+         self.environment = self.create_env(cfg)
+         self.model = get_model(cfg)
+
+         self.record_frames = not (cfg.plancraft.environment.symbolic)
+
+         # no_op action
+         self.no_op = self.environment.action_space.no_op()
+
+     def evaluator_name(self) -> str:
+         symb_str = "real"
+         if self.cfg.plancraft.environment.symbolic:
+             symb_str = "symb"
+
+         if self.cfg.plancraft.use_maskrcnn:
+             symb_str += "_mrcnn"
+
+         model_name = self.cfg.plancraft.model.split("/")[-1]
+         if self.cfg.plancraft.adapter != "":
+             model_name = self.cfg.plancraft.adapter.split("/")[-1]
+
+         mode = self.cfg.plancraft.mode
+         if mode in ["dummy", "oracle"]:
+             return f"{mode}_{symb_str}"
+
+         actions = "|".join(self.cfg.plancraft.valid_actions)
+         return f"{self.cfg.plancraft.mode}_{symb_str}_{model_name}_{actions}"
+
+     def save_results_dict(self, example: PlancraftExample, results_dict: dict):
+         output_dir = f"{self.output_dir}/{self.generation_number}"
+         os.makedirs(output_dir, exist_ok=True)
+         json_path = f"{output_dir}/{example.id}.json"
+         with open(json_path, "w") as f:
+             json.dump(results_dict, f, indent=4)
+         wandb.save(json_path, policy="now")
+
+     def save_images(self, example: PlancraftExample, frames: list):
+         if len(frames) == 0:
+             return
+         output_dir = f"{self.output_dir}/{self.generation_number}"
+         os.makedirs(output_dir, exist_ok=True)
+         imageio.mimsave(f"{output_dir}/{example.id}.gif", frames)
+         # upload to wandb
+         wandb.save(f"{output_dir}/{example.id}.gif", policy="now")
+
+     def load_results_dict(self, example: PlancraftExample) -> dict:
+         path = f"{self.output_dir}/{self.generation_number}/{example.id}.json"
+         if not os.path.exists(path) or not self.cfg.plancraft.resume:
+             return None
+         with open(path, "r") as f:
+             return json.load(f)
+
+     def create_env(self, cfg: EvalConfig) -> RealPlancraft | SymbolicPlancraft:
+         if cfg.plancraft.environment.symbolic:
+             return SymbolicPlancraft(inventory=[])
+         return RealPlancraft(
+             inventory=[],
+             symbolic_action_space=cfg.plancraft.environment.symbolic_action_space,
+             symbolic_observation_space=cfg.plancraft.environment.symbolic_observation_space,
+             preferred_spawn_biome=cfg.plancraft.environment.preferred_spawn_biome,
+             resolution=cfg.plancraft.environment.resolution,
+         )
+
+     def close(self):
+         self.environment.close()
+
+     def load_dataset(self, dataset_split: str) -> list[PlancraftExample]:
+         with open(f"data/{dataset_split}.json", "r") as f:
+             dataset = json.load(f)
+         return [PlancraftExample(**example) for example in dataset]
+
+     def reset(
+         self,
+         example: PlancraftExample,
+     ):
+         current_inventory = example.slotted_inventory
+         self.environment.fast_reset(new_inventory=current_inventory)
+         # do a no-op step to get an initial observation
+         obs, _, _, _ = self.environment.step(self.no_op)
+         # assert that the inventory is correct
+         if "inventory" in obs:
+             for item in current_inventory:
+                 slot = item["slot"]
+                 if (
+                     obs["inventory"][slot]["type"] != item["type"]
+                     or obs["inventory"][slot]["quantity"] != item["quantity"]
+                 ) and item["type"] != "air":
+                     logger.warning(f"Inventory does not match expected for slot {slot}")
+                     logger.warning(f"Expected {item}")
+                     logger.warning(f"Got {obs['inventory'][slot]}")
+                     # try again
+                     self.reset(example)
+
+         objective = f"Craft an item of type: {example.target}"
+         self.model.reset_history(objective=objective)
+
+     def check_done(self, inventory: list[dict[str, int]], target: str):
+         """
+         Check that target object is obtained
+         """
+         for item in inventory:
+             if target == item["type"]:
+                 # ensure item is taken out of crafting slot
+                 if "slot" in item and item["slot"] != 0:
+                     return True
+                 if "index" in item and item["index"] != 0:
+                     return True
+         return False
+
+     @torch.no_grad()
+     def eval_all_examples(self, progress_bar=False) -> list:
+         results = []
+         action = self.no_op.copy()
+
+         pbar = tqdm(
+             total=len(self.examples),
+             disable=not progress_bar,
+         )
+         correct = 0
+         count = 0
+
+         for example in self.examples:
+             if resume_result := self.load_results_dict(example):
+                 pbar.update(self.cfg.plancraft.max_steps)
+                 results.append(resume_result)
+                 continue
+
+             success = False
+
+             self.reset(example)
+             action = self.no_op.copy()
+
+             while (
+                 not self.model.history.check_stuck()
+                 and self.model.history.num_steps < self.cfg.plancraft.max_steps
+             ):
+                 # if the action is stop then we end the episode
+                 if isinstance(action, StopAction):
+                     # if the action is stop and task is impossible then success
+                     # otherwise we should not have stopped
+                     success = example.impossible
+                     break
+
+                 # step action
+                 observation, _, _, _ = self.environment.step(action)
+
+                 # check if the episode is done
+                 success = self.check_done(observation["inventory"], example.target)
+                 # exit if success
+                 if success:
+                     break
+
+                 # predict next action
+                 action = self.model.step(observation)
+
+             # save results and reset
+             result = {
+                 "success": success,
+                 "recipe_type": example.recipe_type,
+                 "number_of_steps": self.model.history.num_steps,
+                 "model_trace": self.model.history.trace(),
+                 "example_id": example.id,
+                 "impossible": example.impossible,
+             }
+             results.append(result)
+             self.save_results_dict(example, result)
+             self.save_images(example, self.model.history.images)
+
+             correct += int(result["success"])
+             count += 1
+
+             acc = correct / count
+             pbar.set_postfix(correct=correct, count=count, acc=acc)
+             pbar.update(1)
+
+         return results
+
+     def eval_all(self):
+         logger.info(
+             f"Running evaluation over {len(self.examples)} examples {self.cfg.plancraft.num_generations} times."
+         )
+         run_name = (
+             f"{self.evaluator_name()} {self.cfg.plancraft.split}".replace(" ", "_")
+             .replace(".", "_")
+             .strip()
+         )
+
+         for n in range(self.cfg.plancraft.num_generations):
+             logger.info(f"Generation {n+1}/{self.cfg.plancraft.num_generations}")
+             run_id = "".join(random.choices(string.ascii_lowercase, k=5))
+             generation_run_name = run_name + f"_{run_id}"
+
+             wandb.init(
+                 name=generation_run_name,
+                 project=self.cfg.wandb.project,
+                 entity=self.cfg.wandb.entity,
+                 mode=self.cfg.wandb.mode,
+                 group=self.cfg.plancraft.model,
+                 job_type=self.cfg.plancraft.mode,
+                 config=self.cfg.model_dump(),
+             )
+             time_now = time.time()
+
+             results_list = self.eval_all_examples(progress_bar=True)
+
+             results_df = pd.DataFrame(results_list)
+
+             output = {
+                 "avg_success_rate": results_df["success"].mean(),
+                 "avg_number_of_steps": results_df["number_of_steps"].mean(),
+                 "avg_num_tokens_used": results_df["model_trace"]
+                 .apply(pd.Series)["tokens_used"]
+                 .mean(),
+             }
+
+             # calculate success rate for each recipe type
+             recipe_types = results_df["recipe_type"].unique()
+             for recipe_type in recipe_types:
+                 mask = results_df["recipe_type"] == recipe_type
+                 success_rate = results_df[mask]["success"].mean()
+                 output[f"{recipe_type}_success_rate"] = success_rate
+
+             time_elapsed = time.time() - time_now
+             logger.info(f"Time elapsed: {time_elapsed:.2f}s")
+
+             logger.info(output)
+             wandb.log(output)
+             table = wandb.Table(
+                 dataframe=results_df[["success", "number_of_steps", "example_id"]]
+             )
+             wandb.log({"results": table})
+             wandb.finish()
+
+             self.generation_number += 1
+
+         logger.info("Done")
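
A minimal driver sketch for the Evaluator above. It assumes an EvalConfig has already been built by the project's own configuration entrypoint (not shown in this diff) and only uses methods defined in evaluator.py.

from plancraft.config import EvalConfig
from plancraft.evaluator import Evaluator


def run_evaluation(cfg: EvalConfig) -> None:
    evaluator = Evaluator(cfg)  # loads data/<split>.json and builds the environment and model
    try:
        evaluator.eval_all()  # one wandb run per generation; per-example results saved under output_dir
    finally:
        evaluator.close()  # shut down the underlying environment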
@@ -0,0 +1,21 @@
+ from plancraft.models.base import ABCModel
+
+ from plancraft.config import EvalConfig
+ from plancraft.models.dummy import DummyModel
+ from plancraft.models.react import ReactModel
+ from plancraft.models.oracle import OracleModel
+ from plancraft.models.act import ActModel
+
+
+ def get_model(cfg: EvalConfig) -> ABCModel:
+     """
+     Factory that returns the model for the configured mode (default: ReactModel)
+     """
+     if cfg.plancraft.mode == "dummy":
+         return DummyModel(cfg)
+     elif cfg.plancraft.mode == "oracle":
+         return OracleModel(cfg)
+     elif cfg.plancraft.mode == "act":
+         return ActModel(cfg)
+     else:
+         return ReactModel(cfg)
@@ -0,0 +1,184 @@
+ import copy
+ import torch
+ from dotenv import load_dotenv
+
+ from plancraft.config import EvalConfig
+ from plancraft.environments.actions import (
+     NoOp,
+     StopAction,
+     SymbolicAction,
+ )
+ from plancraft.models.base import ABCModel, History
+ from plancraft.models.bbox_model import IntegratedBoundingBoxModel
+ from plancraft.models.few_shot_images import load_prompt_images
+ from plancraft.models.generators import (
+     OAMGenerator,
+     OpenAIGenerator,
+     TransformersGenerator,
+ )
+ from plancraft.models.prompts import get_prompt_example, get_system_prompt
+ from plancraft.models.utils import (
+     convert_observation_to_message,
+     parse_content_response,
+ )
+
+
+ load_dotenv()
+
+
+ class ActModel(ABCModel):
+     """
+     Model that acts directly, without a separate thinking step
+     """
+
+     def __init__(self, cfg: EvalConfig):
+         assert (
+             cfg.plancraft.environment.symbolic_action_space
+         ), "Real action space unsupported"
+         self.cfg = cfg
+         self.env_is_multimodal = not cfg.plancraft.environment.symbolic
+         self.use_maskrcnn = cfg.plancraft.use_maskrcnn
+         self.use_multimodal_content_format = cfg.plancraft.use_multimodal_content_format
+         self.use_text_inventory = cfg.plancraft.use_text_inventory
+         self.use_images = cfg.plancraft.use_images
+
+         self.bbox_model = None
+         if self.use_maskrcnn:
+             assert self.env_is_multimodal, "MaskRCNN only supported in multimodal mode"
+             self.bbox_model = IntegratedBoundingBoxModel.from_pretrained(
+                 "gautierdag/plancraft-maskrcnn"
+             )
+             self.bbox_model.eval()
+             if torch.cuda.is_available():
+                 self.bbox_model.cuda()
+             # MaskRCNN is not a multimodal model but a separate detection model
+
+         self.few_shot = cfg.plancraft.few_shot
+         self.use_system_prompt = cfg.plancraft.system_prompt
+         self.max_invalid_actions = 3
+
+         # underlying language model
+         if "gpt-4o" in cfg.plancraft.model:
+             self.use_multimodal_content_format = True
+             self.llm = OpenAIGenerator(
+                 use_images=self.use_images, model_name=cfg.plancraft.model
+             )
+         elif "oam" in cfg.plancraft.model:
+             self.llm = OAMGenerator(model_name=cfg.plancraft.model)
+         else:
+             # model is transformers based
+             self.llm = TransformersGenerator(
+                 model_name=cfg.plancraft.model,
+                 tokenizer_name=cfg.plancraft.tokenizer,
+                 quantize=cfg.plancraft.quantize,
+                 use_hot_cache=cfg.plancraft.hot_cache,
+                 adapter_name=cfg.plancraft.adapter,
+             )
+
+         self.prompt_images = []
+
+         self.valid_actions = cfg.plancraft.valid_actions
+         self.system_prompt_text = get_system_prompt(self.valid_actions)
+
+         examples = get_prompt_example(
+             self.valid_actions,
+             use_text_inventory=self.use_text_inventory,
+             use_multimodal_content_format=self.use_multimodal_content_format,
+             use_images=self.use_images,
+         )
+         if self.env_is_multimodal and self.use_images:
+             self.prompt_images = load_prompt_images()
+
+         if self.use_multimodal_content_format:
+             self.system_prompt = {
+                 "role": "system",
+                 "content": [
+                     {"text": copy.deepcopy(self.system_prompt_text), "type": "text"}
+                 ],
+             }
+         else:
+             self.system_prompt = {
+                 "role": "system",
+                 "content": copy.deepcopy(self.system_prompt_text),
+             }
+
+         if not self.few_shot:
+             examples = []
+         if not self.use_system_prompt:
+             self.system_prompt = None
+
+         self.history = History(
+             initial_dialogue=examples,
+             use_multimodal_content_format=self.use_multimodal_content_format,
+         )
+
+         self.max_messages_window = cfg.plancraft.max_message_window
+         self.kv_cache = None
+
+     def reset_history(
+         self,
+         objective: str,
+     ):
+         examples = []
+         if self.few_shot:
+             examples = get_prompt_example(
+                 self.valid_actions,
+                 use_text_inventory=self.use_text_inventory,
+                 use_multimodal_content_format=self.use_multimodal_content_format,
+                 use_images=self.use_images,
+             )
+
+         self.history.reset(objective=objective, initial_dialogue=examples)
+         self.llm.reset()
+
+     def step(self, observation: dict) -> SymbolicAction | StopAction:
+         self.history.add_observation_to_history(observation)
+
+         # convert the observation into a message for the model
+         observation_message = convert_observation_to_message(
+             observation,
+             objective=self.history.objective,
+             bbox_model=self.bbox_model,
+             oam_model="oam" in self.llm.model_name,
+             use_text_inventory=self.use_text_inventory,
+             use_multimodal_content_format=self.use_multimodal_content_format,
+             use_images=self.use_images,
+         )
+         self.history.add_message_to_history(content=observation_message, role="user")
+
+         # Iterate until valid action
+         i = 0
+         while i < self.max_invalid_actions:
+             # prepare the message window to send to the model
+             message_window, image_window = self.llm.prepare_messages(
+                 history=self.history,
+                 max_messages_window=self.max_messages_window,
+                 system_prompt=self.system_prompt,
+                 prompt_images=self.prompt_images,
+             )
+             action_messages, action_token_used = self.llm.generate_unconstrained(
+                 batch_messages=[message_window],
+                 images=[image_window],
+             )
+             self.history.tokens_used += action_token_used
+
+             action_message = action_messages[0].split("\n")[0].strip()
+
+             self.history.add_message_to_history(
+                 content=action_message, role="assistant"
+             )
+             response = parse_content_response(
+                 action_message, valid_actions=self.valid_actions
+             )
+             if not isinstance(response, str):
+                 # valid action
+                 self.history.add_action_to_history(response)
+                 return response
+
+             self.history.add_message_to_history(
+                 content=response,
+             )
+             i += 1
+
+         # if no valid action is found after max_invalid_actions, default to a no-op action
+         return NoOp()
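
A rough, self-contained sketch of the retry loop used in ActModel.step above: take the first line of the generated text, try to parse it, re-prompt on failure, and fall back to a no-op after three invalid attempts. The generate and parse callables here are stand-ins, not the real plancraft helpers.

from typing import Callable, Optional, Union


def act_with_retries(
    generate: Callable[[], str],
    parse: Callable[[str], Union[tuple, str]],
    max_invalid_actions: int = 3,
) -> Optional[tuple]:
    for _ in range(max_invalid_actions):
        raw = generate().split("\n")[0].strip()
        parsed = parse(raw)
        if not isinstance(parsed, str):
            return parsed  # a parsed action
        # otherwise `parsed` is an error string that would be fed back to the model
    return None  # caller falls back to a no-op action


# usage with trivial stand-ins
action = act_with_retries(
    generate=lambda: "move: from [I18] to [B1] with quantity 1",
    parse=lambda s: ("move", s) if s.startswith("move") else f"Could not parse: {s}",
)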