plancraft 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plancraft-0.1.2.dist-info/METADATA +74 -0
- plancraft-0.1.2.dist-info/RECORD +5 -0
- {plancraft-0.1.1.dist-info → plancraft-0.1.2.dist-info}/WHEEL +1 -1
- plancraft-0.1.2.dist-info/top_level.txt +1 -0
- environments/__init__.py +0 -0
- environments/actions.py +0 -218
- environments/env_real.py +0 -316
- environments/env_symbolic.py +0 -212
- environments/items.py +0 -10
- environments/planner.py +0 -109
- environments/recipes.py +0 -542
- environments/sampler.py +0 -224
- models/__init__.py +0 -21
- models/act.py +0 -184
- models/base.py +0 -152
- models/bbox_model.py +0 -492
- models/dummy.py +0 -54
- models/few_shot_images/__init__.py +0 -16
- models/generators.py +0 -480
- models/oam.py +0 -283
- models/oracle.py +0 -265
- models/prompts.py +0 -158
- models/react.py +0 -93
- models/utils.py +0 -289
- plancraft-0.1.1.dist-info/METADATA +0 -74
- plancraft-0.1.1.dist-info/RECORD +0 -26
- plancraft-0.1.1.dist-info/top_level.txt +0 -3
- train/dataset.py +0 -187
- {plancraft-0.1.1.dist-info → plancraft-0.1.2.dist-info}/LICENSE +0 -0
environments/sampler.py
DELETED
@@ -1,224 +0,0 @@
|
|
1
|
-
import math
|
2
|
-
import random
|
3
|
-
from collections import Counter
|
4
|
-
|
5
|
-
import numpy as np
|
6
|
-
from plancraft.environments.items import all_data, ALL_ITEMS
|
7
|
-
from plancraft.environments.recipes import RECIPES
|
8
|
-
from plancraft.environments.planner import optimal_planner, get_ancestors
|
9
|
-
|
10
|
-
|
11
|
-
# Maximum number of units of each item type that fit in a single inventory
# slot. Items whose data marks them as non-stackable get a stack size of 1.
MAX_STACK_SIZE = {
    data_item["type"]: (data_item["stackSize"] if data_item["stackable"] else 1)
    for data_item in all_data["items"]
}
|
17
|
-
|
18
|
-
|
19
|
-
def sample_distractors(
    exclude_set: set | None = None, num_distractors: int = 16
) -> dict[str, int]:
    """Sample random distractor items (and quantities) for a starting inventory.

    Items are drawn uniformly, with replacement, from ``ALL_ITEMS`` until
    ``num_distractors`` distinct items have been collected. Excluded draws are
    rejected. An item drawn again has its count re-sampled. Each count is
    uniform in ``[1, MAX_STACK_SIZE[item]]``.

    Args:
        exclude_set: Items that must never appear as distractors. ``None``
            excludes nothing.
        num_distractors: Number of distinct distractor items to return.

    Returns:
        Mapping of item name to sampled count.

    Raises:
        ValueError: If fewer than ``num_distractors`` distinct items remain
            after exclusion. The sampling loop could never finish in that case.
    """
    excluded = exclude_set if exclude_set is not None else set()

    # The rejection loop below only terminates if enough distinct candidates
    # exist. Check up front instead of spinning forever.
    available = {item for item in ALL_ITEMS if item not in excluded}
    if len(available) < num_distractors:
        raise ValueError(
            f"Cannot sample {num_distractors} distractors: only "
            f"{len(available)} distinct items are not excluded"
        )

    distractors = {}
    # The draw sequence is kept exactly as before, so seeded dataset
    # generation stays reproducible.
    while len(distractors) < num_distractors:
        item = random.choice(ALL_ITEMS)
        if item in excluded:
            continue
        distractors[item] = random.randint(1, MAX_STACK_SIZE[item])
    return distractors
|
30
|
-
|
31
|
-
|
32
|
-
def assign_to_slots(inventory: dict[str, int]) -> list[dict]:
    """Distribute an item->count inventory over randomly chosen inventory slots.

    Only slots 10-45 are used; these are the slots outside the crafting
    interface. Each item's total is split into stacks of at most
    ``MAX_STACK_SIZE[item]``, one stack per slot, in a random slot order.

    Args:
        inventory: Mapping of item name to total quantity. Non-positive
            quantities produce no slots.

    Returns:
        List of ``{"slot": int, "item": str, "count": int}`` dicts. If the
        inventory needs more slots than are available, a message is printed
        once and the quantities that did not fit are dropped.
    """
    # slots available outside of crafting interface
    available_slots = list(range(10, 46))
    random.shuffle(available_slots)
    inventory_list = []

    for item, total_count in inventory.items():
        while total_count > 0:
            if not available_slots:
                # A plain `break` here only left the inner loop, so the message
                # was repeated for every remaining item. Stop once instead;
                # the returned slots are the same either way.
                print("Not enough slots available")
                return inventory_list
            slot = available_slots.pop()
            count_in_slot = min(total_count, MAX_STACK_SIZE[item])
            inventory_list.append({"slot": slot, "item": item, "count": count_in_slot})
            total_count -= count_in_slot

    return inventory_list
|
49
|
-
|
50
|
-
|
51
|
-
def sample_recipes(
    target: str,
    overall_exclude_set: set,
    target_count: int = 1,
    current_depth: int = 0,
    max_depth: int = 20,
) -> tuple[dict[str, int], set]:
    """Recursively sample raw inputs from which ``target`` can be crafted.

    A random recipe for ``target`` is chosen and its inputs are scaled to
    produce ``target_count`` items. For each input, a coin flip decides
    whether it stays a raw inventory item or is expanded recursively into
    the inputs of one of its own recipes.

    Args:
        target: Item to craft.
        overall_exclude_set: Items already used in the recipe tree. It is
            MUTATED IN PLACE: ``target`` and the sampled inputs of all of
            ``target``'s recipes are added to it.
        target_count: Number of ``target`` items that must be produced.
        current_depth: Current recursion depth.
        max_depth: Past this depth no expansion happens. The value is
            propagated to the recursive calls.

    Returns:
        ``(inputs, exclude_set)``:
        - ``inputs`` maps item name to a positive count. It is an empty dict
          when ``target`` is not expanded; the caller then treats it as a raw
          item.
        - ``exclude_set`` is the updated exclusion set with ``None`` removed.
    """
    # stop if the depth is too high
    if current_depth > max_depth:
        return {}, overall_exclude_set

    # get all the recipes that can craft the target
    overall_exclude_set.update([target])
    local_exclude_set = set()
    random_recipes = []
    for r in RECIPES[target]:
        recipe_inputs, exclude_set = r.sample_inputs()
        # If any recipe for target uses an item already in the tree, target is
        # not expanded at all. This avoids cycles. Note that it abandons the
        # whole target, not just this one recipe.
        # NOTE(review): if sample_inputs() can put None in exclude_set, a None
        # accumulated in the shared overall_exclude_set would also trigger
        # this. None is only stripped from the returned copy below; confirm
        # against recipes.py.
        if exclude_set.intersection(overall_exclude_set):
            return {}, overall_exclude_set
        local_exclude_set.update(exclude_set)
        random_recipes.append((r, recipe_inputs))

    overall_exclude_set |= local_exclude_set

    # no recipes found
    if not random_recipes:
        return {}, overall_exclude_set

    # sample a random recipe
    recipe, start_inputs = random.choice(random_recipes)
    # start_inputs is mutated below. Work on a private copy so it never
    # aliases the object the recipe handed back.
    start_inputs = dict(start_inputs)

    # recipe will not produce enough: it must be applied several times
    if recipe.result.count < target_count:
        recipe_multiplier = math.ceil(target_count / recipe.result.count)
        start_inputs = {k: v * recipe_multiplier for k, v in start_inputs.items()}

    # Snapshot the keys: children inputs are merged into start_inputs in the loop.
    for input_item in list(start_inputs.keys()):
        # randomize depth first search to end early
        if random.choice([True, False]):
            continue

        children_recipe_inputs, updated_exclude_set = sample_recipes(
            target=input_item,
            overall_exclude_set=overall_exclude_set,
            target_count=start_inputs[input_item],
            current_depth=current_depth + 1,
            max_depth=max_depth,
        )
        if not children_recipe_inputs:
            continue

        overall_exclude_set.update(updated_exclude_set)

        # remove recipe input item since we are crafting it
        start_inputs[input_item] = 0

        # add the children recipe inputs
        for item, count in children_recipe_inputs.items():
            start_inputs[item] = start_inputs.get(item, 0) + count

    overall_exclude_set = overall_exclude_set - {None}
    start_inputs = {k: v for k, v in start_inputs.items() if v > 0}

    return start_inputs, overall_exclude_set
|
117
|
-
|
118
|
-
|
119
|
-
def remove_ancestor_items(target: str, inventory: dict[str, int]) -> dict[str, int]:
    """Randomly remove some of ``target``'s ancestor items from ``inventory``.

    A random non-empty subset of the ancestor items present in the inventory
    is chosen. For each chosen item, a random amount between 1 and its full
    count is removed. Entries that reach zero are deleted.

    Args:
        target: Item whose ancestors (from ``get_ancestors``) are candidates.
        inventory: Mapping of item name to count. It is MODIFIED IN PLACE.

    Returns:
        The same ``inventory`` object, after removal.

    Raises:
        ValueError: If ``inventory`` contains none of ``target``'s ancestors.
    """
    ancestors = set(get_ancestors(target))
    # Sort the candidates. Set iteration order of strings varies between
    # processes (hash randomization), which made random.sample's picks
    # irreproducible even with a fixed random seed.
    items_to_remove = sorted(ancestors.intersection(inventory.keys()))
    if not items_to_remove:
        raise ValueError(
            f"No ancestor items of {target!r} left in the inventory to remove"
        )

    num_items = random.randint(1, len(items_to_remove))
    for item in random.sample(items_to_remove, num_items):
        count_to_remove = random.randint(1, inventory[item])
        inventory[item] -= count_to_remove
        if inventory[item] == 0:
            del inventory[item]
    return inventory
|
130
|
-
|
131
|
-
|
132
|
-
def construct_example(
    target: str,
    num_distractors: int = 16,
    impossible: bool = False,
) -> dict:
    """
    Build one crafting task whose goal is to craft ``target``.

    The starting inventory holds the resources of a randomly sampled recipe
    tree for ``target``, plus ``num_distractors`` random unrelated items.
    If ``impossible`` is True, ancestor items are removed until the planner
    finds no way to craft ``target``.

    Args:
        target: Item to be crafted.
        num_distractors: Number of distinct distractor items to add.
        impossible: Whether the task must be unsolvable.

    Returns:
        Dict with keys "inventory", "slotted_inventory", "target",
        "num_distractors" and "impossible". For solvable tasks it also
        contains "optimal_path_length", "optimal_path" (the result item of
        each planner step), "inventory_trace", "items_used" and
        "unique_items_used".

    Raises:
        AssertionError: If the planner result disagrees with ``impossible``.
    """
    # sample the recipe
    inventory, overall_exclude_set = sample_recipes(target, set())
    if impossible:
        # if impossible then remove one or more items from the inventory
        inventory = remove_ancestor_items(target, inventory)

    # add distractors to the inventory
    distractors = sample_distractors(overall_exclude_set, num_distractors)
    inventory.update(distractors)

    optimal_path = optimal_planner(target, inventory)
    # HACK: one random removal does not guarantee the task is unsolvable.
    # Keep removing ancestor items until the planner finds no path.
    while optimal_path is not None and impossible:
        inventory = remove_ancestor_items(target, inventory)
        optimal_path = optimal_planner(target, inventory)

    # assign to slots
    inventory_list = assign_to_slots(inventory)
    example = {
        "inventory": inventory,
        "slotted_inventory": inventory_list,
        "target": target,
        "num_distractors": num_distractors,
        "impossible": impossible,
    }
    # either impossible and no path or not impossible and path exists
    assert (impossible and optimal_path is None) or (
        not impossible and optimal_path is not None
    ), f"Planner result inconsistent with impossible={impossible} for {target!r}"

    if not impossible:
        example["optimal_path_length"] = len(optimal_path)
        example["optimal_path"] = [r.result.item for (r, i) in optimal_path]
        example["inventory_trace"] = [i for (r, i) in optimal_path]
        items_used, unique_items_used = calculate_stats_from_inventory_trace(
            [example["inventory"]] + example["inventory_trace"]
        )
        example["items_used"] = items_used
        example["unique_items_used"] = unique_items_used

    return example
|
193
|
-
|
194
|
-
|
195
|
-
def calculate_stats_from_inventory_trace(
    inventory_trace: list[dict],
) -> tuple[int, int]:
    """
    Count how many items are consumed along a sequence of inventories.

    Each inventory maps item -> count. For every pair of consecutive
    inventories, only the items whose count went down are tallied. Items
    that stay the same or increase are ignored.

    Returns:
        ``(items_used, unique_items_used)``:
        - ``items_used`` is the total number of units that disappeared.
        - ``unique_items_used`` is the number of (step, item) pairs whose
          count decreased.
    """
    consumed_per_step = [
        Counter(before) - Counter(after)
        for before, after in zip(inventory_trace, inventory_trace[1:])
    ]
    items_used = sum(sum(consumed.values()) for consumed in consumed_per_step)
    unique_items_used = sum(len(consumed) for consumed in consumed_per_step)
    return items_used, unique_items_used
|
207
|
-
|
208
|
-
|
209
|
-
def generate_dataset(seed=2024, distractors=(4, 8, 16), num_examples=10):
    """Generate a seeded dataset of solvable crafting examples.

    For every item in ``RECIPES`` that has at least one recipe, and for every
    distractor count in ``distractors``, ``num_examples`` examples are built
    with ``construct_example``. ``impossible`` is left at its default, False.

    Args:
        seed: Seed applied to both ``random`` and ``numpy.random``.
        distractors: Iterable of distractor counts. A tuple default replaces
            the previous mutable list default.
        num_examples: Number of examples per (target, distractor count) pair.

    Returns:
        List of example dicts as produced by ``construct_example``.
    """
    random.seed(seed)
    np.random.seed(seed)

    dataset = []
    # Iterate over a snapshot of the keys, in case RECIPES gains keys while
    # examples are being sampled.
    # NOTE(review): e.g. if RECIPES is a defaultdict; confirm in recipes.py.
    for recipe_target in list(RECIPES.keys()):
        if not RECIPES[recipe_target]:
            continue
        for num_distractors in distractors:
            for _ in range(num_examples):
                example = construct_example(
                    target=recipe_target, num_distractors=num_distractors
                )
                dataset.append(example)

    return dataset
|
models/__init__.py
DELETED
@@ -1,21 +0,0 @@
|
|
1
|
-
from plancraft.models.base import ABCModel
|
2
|
-
|
3
|
-
from plancraft.config import EvalConfig
|
4
|
-
from plancraft.models.dummy import DummyModel
|
5
|
-
from plancraft.models.react import ReactModel
|
6
|
-
from plancraft.models.oracle import OracleModel
|
7
|
-
from plancraft.models.act import ActModel
|
8
|
-
|
9
|
-
|
10
|
-
def get_model(cfg: EvalConfig) -> ABCModel:
|
11
|
-
"""
|
12
|
-
Factory get model (default: ReactModel)
|
13
|
-
"""
|
14
|
-
if cfg.plancraft.mode == "dummy":
|
15
|
-
return DummyModel(cfg)
|
16
|
-
elif cfg.plancraft.mode == "oracle":
|
17
|
-
return OracleModel(cfg)
|
18
|
-
elif cfg.plancraft.mode == "act":
|
19
|
-
return ActModel(cfg)
|
20
|
-
else:
|
21
|
-
return ReactModel(cfg)
|
models/act.py
DELETED
@@ -1,184 +0,0 @@
|
|
1
|
-
import copy
|
2
|
-
import torch
|
3
|
-
from dotenv import load_dotenv
|
4
|
-
|
5
|
-
from plancraft.config import EvalConfig
|
6
|
-
from plancraft.environments.actions import (
|
7
|
-
NoOp,
|
8
|
-
StopAction,
|
9
|
-
SymbolicAction,
|
10
|
-
)
|
11
|
-
from plancraft.models.base import ABCModel, History
|
12
|
-
from plancraft.models.bbox_model import IntegratedBoundingBoxModel
|
13
|
-
from plancraft.models.few_shot_images import load_prompt_images
|
14
|
-
from plancraft.models.generators import (
|
15
|
-
OAMGenerator,
|
16
|
-
OpenAIGenerator,
|
17
|
-
TransformersGenerator,
|
18
|
-
)
|
19
|
-
from plancraft.models.prompts import get_prompt_example, get_system_prompt
|
20
|
-
from plancraft.models.utils import (
|
21
|
-
convert_observation_to_message,
|
22
|
-
parse_content_response,
|
23
|
-
)
|
24
|
-
|
25
|
-
|
26
|
-
# Load variables from a local .env file into the process environment at import
# time. NOTE(review): presumably credentials for the generators, e.g.
# OpenAIGenerator; confirm.
load_dotenv()
|
27
|
-
|
28
|
-
|
29
|
-
class ActModel(ABCModel):
    """
    Model that emits an action directly, without a separate reasoning step.

    On each ``step``, the observation is appended to the dialogue and the
    language model is queried. The first line of its reply is then parsed as
    an action. An unparseable reply is recorded and the query is retried, up
    to ``max_invalid_actions`` attempts, before falling back to ``NoOp``.
    """

    def __init__(self, cfg: EvalConfig):
        plancraft_cfg = cfg.plancraft
        env_cfg = plancraft_cfg.environment
        assert env_cfg.symbolic_action_space, "Real action space unsupported"
        self.cfg = cfg

        # observation / prompt formatting options
        self.env_is_multimodal = not env_cfg.symbolic
        self.use_maskrcnn = plancraft_cfg.use_maskrcnn
        self.use_multimodal_content_format = plancraft_cfg.use_multimodal_content_format
        self.use_text_inventory = plancraft_cfg.use_text_inventory
        self.use_images = plancraft_cfg.use_images

        # MaskRCNN is not a multimodal LLM but a separate bounding-box model.
        # It is passed along when observations are converted to messages.
        self.bbox_model = self._load_bbox_model() if self.use_maskrcnn else None

        self.few_shot = plancraft_cfg.few_shot
        self.use_system_prompt = plancraft_cfg.system_prompt
        self.max_invalid_actions = 3

        # underlying language model (may switch on the multimodal content format)
        self.llm = self._build_llm(plancraft_cfg)

        self.prompt_images = []
        self.valid_actions = plancraft_cfg.valid_actions
        self.system_prompt_text = get_system_prompt(self.valid_actions)

        examples = self._prompt_examples()
        if self.env_is_multimodal and self.use_images:
            self.prompt_images = load_prompt_images()

        self.system_prompt = self._build_system_prompt()

        if not self.few_shot:
            examples = []
        if not self.use_system_prompt:
            self.system_prompt = None

        self.history = History(
            initial_dialogue=examples,
            use_multimodal_content_format=self.use_multimodal_content_format,
        )

        self.max_messages_window = plancraft_cfg.max_message_window
        self.kv_cache = None

    def _load_bbox_model(self):
        """Load the pretrained MaskRCNN model in eval mode, on GPU when available."""
        assert self.env_is_multimodal, "MaskRCNN only supported in multimodal mode"
        bbox_model = IntegratedBoundingBoxModel.from_pretrained(
            "gautierdag/plancraft-maskrcnn"
        )
        bbox_model.eval()
        if torch.cuda.is_available():
            bbox_model.cuda()
        return bbox_model

    def _build_llm(self, plancraft_cfg):
        """Pick the generator backend from the configured model name."""
        model_name = plancraft_cfg.model
        if "gpt-4o" in model_name:
            # gpt-4o models always use the list-of-parts content format.
            self.use_multimodal_content_format = True
            return OpenAIGenerator(use_images=self.use_images, model_name=model_name)
        if "oam" in model_name:
            return OAMGenerator(model_name=model_name)
        # any other model name is treated as transformers based
        return TransformersGenerator(
            model_name=model_name,
            tokenizer_name=plancraft_cfg.tokenizer,
            quantize=plancraft_cfg.quantize,
            use_hot_cache=plancraft_cfg.hot_cache,
            adapter_name=plancraft_cfg.adapter,
        )

    def _prompt_examples(self):
        """Few-shot example dialogue matching the current prompt settings."""
        return get_prompt_example(
            self.valid_actions,
            use_text_inventory=self.use_text_inventory,
            use_multimodal_content_format=self.use_multimodal_content_format,
            use_images=self.use_images,
        )

    def _build_system_prompt(self) -> dict:
        """System message, shaped like the rest of the dialogue's content."""
        prompt_text = copy.deepcopy(self.system_prompt_text)
        if self.use_multimodal_content_format:
            return {
                "role": "system",
                "content": [{"text": prompt_text, "type": "text"}],
            }
        return {"role": "system", "content": prompt_text}

    def reset_history(
        self,
        objective: str,
    ):
        """Start a new episode for ``objective``.

        Rebuilds the few-shot prefix (empty when few-shot is disabled), resets
        the dialogue history with it, and resets the generator.
        """
        examples = self._prompt_examples() if self.few_shot else []
        self.history.reset(objective=objective, initial_dialogue=examples)
        self.llm.reset()

    def step(self, observation: dict) -> SymbolicAction | StopAction:
        """Record ``observation`` and return the next action from the LLM."""
        history = self.history
        history.add_observation_to_history(observation)

        # Render the observation as the next user message.
        observation_message = convert_observation_to_message(
            observation,
            objective=history.objective,
            bbox_model=self.bbox_model,
            oam_model="oam" in self.llm.model_name,
            use_text_inventory=self.use_text_inventory,
            use_multimodal_content_format=self.use_multimodal_content_format,
            use_images=self.use_images,
        )
        history.add_message_to_history(content=observation_message, role="user")

        for _ in range(self.max_invalid_actions):
            # Build the (windowed) prompt from the history and query the model.
            message_window, image_window = self.llm.prepare_messages(
                history=history,
                max_messages_window=self.max_messages_window,
                system_prompt=self.system_prompt,
                prompt_images=self.prompt_images,
            )
            action_messages, action_token_used = self.llm.generate_unconstrained(
                batch_messages=[message_window],
                images=[image_window],
            )
            history.tokens_used += action_token_used

            # Only the first line of the reply is treated as the action.
            action_message = action_messages[0].split("\n")[0].strip()
            history.add_message_to_history(content=action_message, role="assistant")

            response = parse_content_response(
                action_message, valid_actions=self.valid_actions
            )
            if isinstance(response, str):
                # A string means no valid action was parsed: record it as a
                # user message and try again.
                history.add_message_to_history(content=response)
                continue

            history.add_action_to_history(response)
            return response

        # every attempt was invalid: fall back to an action that does nothing
        return NoOp()
|
models/base.py
DELETED
@@ -1,152 +0,0 @@
|
|
1
|
-
import abc
|
2
|
-
|
3
|
-
from copy import copy
|
4
|
-
from collections import Counter
|
5
|
-
|
6
|
-
from plancraft.environments.actions import (
|
7
|
-
SymbolicMoveAction,
|
8
|
-
RealActionInteraction,
|
9
|
-
SymbolicSmeltAction,
|
10
|
-
)
|
11
|
-
|
12
|
-
|
13
|
-
class History:
|
14
|
-
def __init__(
|
15
|
-
self,
|
16
|
-
objective: str = "",
|
17
|
-
initial_dialogue: list[dict] = [],
|
18
|
-
use_multimodal_content_format=False,
|
19
|
-
):
|
20
|
-
self.dialogue_history = initial_dialogue
|
21
|
-
self.initial_dialogue_length = len(initial_dialogue)
|
22
|
-
self.action_history = []
|
23
|
-
self.inventory_history = []
|
24
|
-
self.inventory_counters = []
|
25
|
-
self.images = []
|
26
|
-
self.objective = objective
|
27
|
-
self.tokens_used = 0
|
28
|
-
self.use_multimodal_content_format = use_multimodal_content_format
|
29
|
-
|
30
|
-
def add_message_to_history(self, content: str | dict, role="user"):
|
31
|
-
if role == "assistant":
|
32
|
-
print(content)
|
33
|
-
|
34
|
-
if isinstance(content, dict):
|
35
|
-
assert "content" in content, "content key not found in message"
|
36
|
-
content["role"] = role
|
37
|
-
self.dialogue_history.append(content)
|
38
|
-
else:
|
39
|
-
# fix for listed content type
|
40
|
-
if self.use_multimodal_content_format:
|
41
|
-
return self.add_message_to_history(
|
42
|
-
content={
|
43
|
-
"content": [{"type": "text", "text": content}],
|
44
|
-
"role": role,
|
45
|
-
},
|
46
|
-
role=role,
|
47
|
-
)
|
48
|
-
else:
|
49
|
-
self.dialogue_history.append({"role": role, "content": content})
|
50
|
-
|
51
|
-
def add_action_to_history(
|
52
|
-
self, action: SymbolicSmeltAction | RealActionInteraction | SymbolicMoveAction
|
53
|
-
):
|
54
|
-
if action is None:
|
55
|
-
return
|
56
|
-
self.action_history.append(action.model_dump())
|
57
|
-
|
58
|
-
def add_inventory_to_history(self, inventory: list[dict[str, int]]):
|
59
|
-
self.inventory_history.append(inventory)
|
60
|
-
|
61
|
-
# count inventory
|
62
|
-
counter = Counter()
|
63
|
-
for item in inventory:
|
64
|
-
# ignore slot 0
|
65
|
-
if "slot" in item and item["slot"] == 0:
|
66
|
-
continue
|
67
|
-
if "index" in item and item["index"] == 0:
|
68
|
-
continue
|
69
|
-
counter[item["type"]] += item["quantity"]
|
70
|
-
|
71
|
-
self.inventory_counters.append(counter)
|
72
|
-
|
73
|
-
def add_image_to_history(self, image):
|
74
|
-
self.images.append(image)
|
75
|
-
|
76
|
-
def add_observation_to_history(self, observation: dict):
|
77
|
-
if observation is None:
|
78
|
-
return
|
79
|
-
if "inventory" in observation:
|
80
|
-
clean_inv = []
|
81
|
-
# remove empty slots
|
82
|
-
for item in observation["inventory"]:
|
83
|
-
if item["quantity"] > 0:
|
84
|
-
clean_inv.append(item)
|
85
|
-
self.add_inventory_to_history(clean_inv)
|
86
|
-
if "pov" in observation:
|
87
|
-
self.add_image_to_history(observation["pov"])
|
88
|
-
|
89
|
-
def __str__(self):
|
90
|
-
return str(self.dialogue_history)
|
91
|
-
|
92
|
-
def reset(self, objective: str = "", initial_dialogue: list[dict] = []):
|
93
|
-
self.dialogue_history = initial_dialogue
|
94
|
-
self.action_history = []
|
95
|
-
self.inventory_history = []
|
96
|
-
self.inventory_counters = []
|
97
|
-
self.images = []
|
98
|
-
self.objective = objective
|
99
|
-
|
100
|
-
def set_objective(self, objective: str):
|
101
|
-
self.objective = objective
|
102
|
-
|
103
|
-
def trace(self):
|
104
|
-
return {
|
105
|
-
"dialogue_history": copy(
|
106
|
-
self.dialogue_history[self.initial_dialogue_length :]
|
107
|
-
),
|
108
|
-
"action_history": copy(self.action_history),
|
109
|
-
"inventory_history": copy(self.inventory_history),
|
110
|
-
"objective": copy(self.objective),
|
111
|
-
"tokens_used": copy(self.tokens_used),
|
112
|
-
}
|
113
|
-
|
114
|
-
@property
|
115
|
-
def num_steps(self):
|
116
|
-
return len(self.action_history)
|
117
|
-
|
118
|
-
def check_stuck(self, max_steps_no_change: int = 10) -> bool:
|
119
|
-
"""
|
120
|
-
If inventory content does not change for max_steps_no_change steps
|
121
|
-
the agent is considered stuck.
|
122
|
-
|
123
|
-
With N=10, the oracle solver can still solve 100% of the examples
|
124
|
-
"""
|
125
|
-
if len(self.inventory_counters) <= max_steps_no_change:
|
126
|
-
return False
|
127
|
-
|
128
|
-
return all(
|
129
|
-
c == self.inventory_counters[-max_steps_no_change - 1]
|
130
|
-
for c in self.inventory_counters[-max_steps_no_change - 1 :]
|
131
|
-
)
|
132
|
-
|
133
|
-
|
134
|
-
class ABCModel(abc.ABC):
    """
    Interface a model must implement to be driven by the evaluator.

    Subclasses are expected to provide a ``history`` attribute; the default
    ``reset_history`` resets it.
    """

    @abc.abstractmethod
    def step(
        self, observation: list[dict]
    ) -> list[SymbolicMoveAction | RealActionInteraction | SymbolicSmeltAction]:
        """
        Produce the next action(s) for the given observation(s).

        The declared contract is batched: one action per observation in the
        batch, each being one of the three action types above.
        NOTE(review): ActModel.step takes a single observation dict and returns
        a single action. Confirm which contract the evaluator actually relies on.
        """
        raise NotImplementedError

    def reset_history(self, objective: str = ""):
        """Begin a new episode by resetting ``self.history`` with ``objective``."""
        self.history.reset(objective=objective)
|