plancraft 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- plancraft-0.1.2.dist-info/METADATA +74 -0
- plancraft-0.1.2.dist-info/RECORD +5 -0
- {plancraft-0.1.1.dist-info → plancraft-0.1.2.dist-info}/WHEEL +1 -1
- plancraft-0.1.2.dist-info/top_level.txt +1 -0
- environments/__init__.py +0 -0
- environments/actions.py +0 -218
- environments/env_real.py +0 -316
- environments/env_symbolic.py +0 -212
- environments/items.py +0 -10
- environments/planner.py +0 -109
- environments/recipes.py +0 -542
- environments/sampler.py +0 -224
- models/__init__.py +0 -21
- models/act.py +0 -184
- models/base.py +0 -152
- models/bbox_model.py +0 -492
- models/dummy.py +0 -54
- models/few_shot_images/__init__.py +0 -16
- models/generators.py +0 -480
- models/oam.py +0 -283
- models/oracle.py +0 -265
- models/prompts.py +0 -158
- models/react.py +0 -93
- models/utils.py +0 -289
- plancraft-0.1.1.dist-info/METADATA +0 -74
- plancraft-0.1.1.dist-info/RECORD +0 -26
- plancraft-0.1.1.dist-info/top_level.txt +0 -3
- train/dataset.py +0 -187
- {plancraft-0.1.1.dist-info → plancraft-0.1.2.dist-info}/LICENSE +0 -0
environments/sampler.py
DELETED
@@ -1,224 +0,0 @@
|
|
1
|
-
import math
|
2
|
-
import random
|
3
|
-
from collections import Counter
|
4
|
-
|
5
|
-
import numpy as np
|
6
|
-
from plancraft.environments.items import all_data, ALL_ITEMS
|
7
|
-
from plancraft.environments.recipes import RECIPES
|
8
|
-
from plancraft.environments.planner import optimal_planner, get_ancestors
|
9
|
-
|
10
|
-
|
11
|
-
# Per-item-type capacity of a single inventory slot, built from the item data:
# stackable items use their "stackSize"; non-stackable items hold exactly 1.
MAX_STACK_SIZE = {
    entry["type"]: entry["stackSize"] if entry["stackable"] else 1
    for entry in all_data["items"]
}
|
17
|
-
|
18
|
-
|
19
|
-
def sample_distractors(
|
20
|
-
exclude_set: set = None, num_distractors: int = 16
|
21
|
-
) -> dict[str, int]:
|
22
|
-
distractors = {}
|
23
|
-
while len(distractors) < num_distractors:
|
24
|
-
item = random.choice(ALL_ITEMS)
|
25
|
-
if exclude_set is not None and item in exclude_set:
|
26
|
-
continue
|
27
|
-
count = random.randint(1, MAX_STACK_SIZE[item])
|
28
|
-
distractors[item] = count
|
29
|
-
return distractors
|
30
|
-
|
31
|
-
|
32
|
-
def assign_to_slots(inventory: dict[str, int]) -> list[dict]:
    """
    Spread an item->count inventory over randomly chosen inventory slots.

    Only slots 10..45 (those outside the crafting interface) are used, in a
    shuffled order. An item whose count exceeds its ``MAX_STACK_SIZE`` is
    split across several slots. When the free slots run out, the message
    "Not enough slots available" is printed and the remaining units of the
    current item are dropped.

    Returns:
        A list of ``{"slot": int, "item": str, "count": int}`` dicts.
    """
    free_slots = list(range(10, 46))
    random.shuffle(free_slots)

    placed = []
    for item_name, remaining in inventory.items():
        while remaining > 0:
            if not free_slots:
                print("Not enough slots available")
                break
            chosen_slot = free_slots.pop()
            stack = min(remaining, MAX_STACK_SIZE[item_name])
            placed.append({"slot": chosen_slot, "item": item_name, "count": stack})
            remaining -= stack
    return placed
|
49
|
-
|
50
|
-
|
51
|
-
def sample_recipes(
    target: str,
    overall_exclude_set: set,
    target_count: int = 1,
    current_depth: int = 0,
    max_depth: int = 20,
) -> tuple[dict[str, int], set]:
    """
    Randomly expand a crafting tree for ``target`` into starting inputs.

    One recipe for ``target`` is picked at random. If a single craft yields
    fewer than ``target_count`` items, its inputs are scaled up. Each input
    is then, with probability 1/2, recursively replaced by the inputs of one
    of its own recipes.

    Args:
        target: item to craft.
        overall_exclude_set: items already involved in the tree. Mutated in
            place: ``target`` and the inputs of every candidate recipe are
            added to it.
        target_count: how many ``target`` items are needed.
        current_depth: recursion depth of this call.
        max_depth: no expansion happens once ``current_depth`` exceeds this.
            It is now propagated to recursive calls; previously children
            always used the default of 20.

    Returns:
        ``(start_inputs, exclude_set)``. ``start_inputs`` maps item -> count
        (an empty dict when ``target`` is not expanded). ``exclude_set`` is
        the updated exclude set with ``None`` removed.
    """
    # stop if the depth is too high
    if current_depth > max_depth:
        return {}, overall_exclude_set

    # get all the recipes that can craft the target
    overall_exclude_set.update([target])
    local_exclude_set = set()
    random_recipes = []
    for r in RECIPES[target]:
        # presumably (inputs dict, set of items the recipe touches)
        recipe_inputs, exclude_set = r.sample_inputs()
        # Abandon the target when any recipe's inputs are already in use
        # (prevents cycles).
        # NOTE(review): the original comment said "skip this recipe", but the
        # code returns for the whole target; kept as-is to avoid changing the
        # generated datasets.
        if exclude_set.intersection(overall_exclude_set):
            return {}, overall_exclude_set
        local_exclude_set.update(exclude_set)
        random_recipes.append((r, recipe_inputs))

    # inputs of *all* candidate recipes are excluded, not only the chosen one
    overall_exclude_set |= local_exclude_set

    # no recipes found
    if not random_recipes:
        return {}, overall_exclude_set

    # sample a random recipe
    recipe, start_inputs = random.choice(random_recipes)

    # recipe will not produce enough: it must be executed several times
    if recipe.result.count < target_count:
        recipe_multiplier = math.ceil(target_count / recipe.result.count)
        start_inputs = {k: v * recipe_multiplier for k, v in start_inputs.items()}

    for input_item in list(start_inputs):
        # randomize depth-first search so trees end early half of the time
        if random.choice([True, False]):
            continue

        children_recipe_inputs, updated_exclude_set = sample_recipes(
            target=input_item,
            overall_exclude_set=overall_exclude_set,
            target_count=start_inputs[input_item],
            current_depth=current_depth + 1,
            max_depth=max_depth,
        )
        if not children_recipe_inputs:
            continue

        overall_exclude_set.update(updated_exclude_set)

        # remove recipe input item since we are crafting it
        start_inputs[input_item] = 0

        # add the children recipe inputs
        for item, count in children_recipe_inputs.items():
            start_inputs[item] = start_inputs.get(item, 0) + count

    overall_exclude_set = overall_exclude_set - {None}
    start_inputs = {k: v for k, v in start_inputs.items() if v > 0}

    return start_inputs, overall_exclude_set
|
117
|
-
|
118
|
-
|
119
|
-
def remove_ancestor_items(target: str, inventory: dict[str, int]) -> dict[str, int]:
    """
    Randomly take away some of ``target``'s ancestor items from ``inventory``.

    The candidates are the inventory items returned by
    ``get_ancestors(target)`` (presumably the ingredients in the target's
    crafting tree). Between 1 and all of them are chosen. For each chosen
    item, between 1 and all of its count is removed, and items that reach 0
    are deleted.

    Args:
        target: item whose ancestors should be depleted.
        inventory: item -> count mapping. Mutated in place.

    Returns:
        The same ``inventory`` dict.

    Raises:
        ValueError: if no ancestor of ``target`` is in ``inventory``.
    """
    ancestors = set(get_ancestors(target))
    # Sort the candidates: iterating a set of strings follows per-process hash
    # randomization, which made random.sample non-reproducible under a seed.
    items_to_remove = sorted(ancestors.intersection(inventory), key=str)
    if not items_to_remove:
        raise ValueError(
            f"No ancestor items of {target!r} found in the inventory to remove"
        )

    num_items = random.randint(1, len(items_to_remove))
    for item in random.sample(items_to_remove, num_items):
        count_to_remove = random.randint(1, inventory[item])
        inventory[item] -= count_to_remove
        if inventory[item] == 0:
            del inventory[item]
    return inventory
|
130
|
-
|
131
|
-
|
132
|
-
def construct_example(
    target: str,
    num_distractors: int = 16,
    impossible: bool = False,
) -> dict:
    """
    Build one crafting task whose goal is to craft ``target``.

    The start inventory holds the resources sampled by ``sample_recipes``,
    plus ``num_distractors`` random unrelated items. If ``impossible`` is
    True, ancestor items are removed until ``optimal_planner`` finds no path,
    so the target cannot be crafted from the inventory.

    Returns:
        A dict with the keys "inventory", "slotted_inventory", "target",
        "num_distractors" and "impossible". Possible examples also get
        "optimal_path_length", "optimal_path", "inventory_trace",
        "items_used" and "unique_items_used".

    Raises:
        AssertionError: if the planner result disagrees with ``impossible``.
    """
    # sample the recipe
    inventory, overall_exclude_set = sample_recipes(target, set())
    if impossible:
        # if impossible then remove one or more items from the inventory
        inventory = remove_ancestor_items(
            target,
            inventory,
        )

    # add distractors to the inventory
    distractors = sample_distractors(overall_exclude_set, num_distractors)
    inventory.update(distractors)

    optimal_path = optimal_planner(target, inventory)
    # HACK: keep removing ancestors until the impossible example is truly
    # unsolvable
    while optimal_path is not None and impossible:
        inventory = remove_ancestor_items(target, inventory)
        optimal_path = optimal_planner(target, inventory)

    # assign to slots
    inventory_list = assign_to_slots(inventory)
    example = {
        "inventory": inventory,
        "slotted_inventory": inventory_list,
        "target": target,
        "num_distractors": num_distractors,
        "impossible": impossible,
    }

    # either impossible and no path, or possible and a path exists; an
    # explicit check (not `assert`) so it survives `python -O`
    if bool(impossible) != (optimal_path is None):
        raise AssertionError(
            "Planner result is inconsistent with the impossible flag "
            f"for target {target!r}"
        )

    if not impossible:
        example["optimal_path_length"] = len(optimal_path)
        example["optimal_path"] = [r.result.item for (r, i) in optimal_path]
        example["inventory_trace"] = [i for (r, i) in optimal_path]
        items_used, unique_items_used = calculate_stats_from_inventory_trace(
            [example["inventory"]] + example["inventory_trace"]
        )
        example["items_used"] = items_used
        example["unique_items_used"] = unique_items_used

    return example
|
193
|
-
|
194
|
-
|
195
|
-
def calculate_stats_from_inventory_trace(
    inventory_trace: list[dict],
) -> tuple[int, int]:
    """
    Measure how much inventory was consumed along a trace of inventories.

    For every consecutive pair (before, after), the positive part of
    ``Counter(before) - Counter(after)`` counts as consumed.

    Returns:
        ``(total_items_used, total_unique_items_used)``: the summed consumed
        quantities and the summed number of distinct consumed item types,
        over all steps.
    """
    step_deltas = [
        Counter(before) - Counter(after)
        for before, after in zip(inventory_trace, inventory_trace[1:])
    ]
    total_items = sum(sum(delta.values()) for delta in step_deltas)
    total_unique = sum(len(delta) for delta in step_deltas)
    return total_items, total_unique
|
207
|
-
|
208
|
-
|
209
|
-
def generate_dataset(
    seed: int = 2024, distractors=(4, 8, 16), num_examples: int = 10
) -> list[dict]:
    """
    Generate craftable (non-impossible) examples for every recipe target.

    Both ``random`` and ``numpy.random`` are seeded with ``seed``. For each
    target in ``RECIPES`` that has at least one recipe, and for each value in
    ``distractors``, ``num_examples`` examples are built with
    ``construct_example``.

    Args:
        seed: RNG seed.
        distractors: iterable of distractor counts. The default is a tuple
            (not a list) to avoid a shared mutable default argument; any
            iterable of ints is accepted.
        num_examples: examples per (target, distractor count) pair.

    Returns:
        List of example dicts.
    """
    random.seed(seed)
    np.random.seed(seed)

    dataset = []
    # snapshot the keys so the iteration is unaffected if RECIPES changes
    for recipe_target in list(RECIPES.keys()):
        if len(RECIPES[recipe_target]) == 0:
            continue
        for num_distractors in distractors:
            for _ in range(num_examples):
                example = construct_example(
                    target=recipe_target, num_distractors=num_distractors
                )
                dataset.append(example)

    return dataset
|
models/__init__.py
DELETED
@@ -1,21 +0,0 @@
|
|
1
|
-
from plancraft.models.base import ABCModel
|
2
|
-
|
3
|
-
from plancraft.config import EvalConfig
|
4
|
-
from plancraft.models.dummy import DummyModel
|
5
|
-
from plancraft.models.react import ReactModel
|
6
|
-
from plancraft.models.oracle import OracleModel
|
7
|
-
from plancraft.models.act import ActModel
|
8
|
-
|
9
|
-
|
10
|
-
def get_model(cfg: EvalConfig) -> ABCModel:
|
11
|
-
"""
|
12
|
-
Factory get model (default: ReactModel)
|
13
|
-
"""
|
14
|
-
if cfg.plancraft.mode == "dummy":
|
15
|
-
return DummyModel(cfg)
|
16
|
-
elif cfg.plancraft.mode == "oracle":
|
17
|
-
return OracleModel(cfg)
|
18
|
-
elif cfg.plancraft.mode == "act":
|
19
|
-
return ActModel(cfg)
|
20
|
-
else:
|
21
|
-
return ReactModel(cfg)
|
models/act.py
DELETED
@@ -1,184 +0,0 @@
|
|
1
|
-
import copy
|
2
|
-
import torch
|
3
|
-
from dotenv import load_dotenv
|
4
|
-
|
5
|
-
from plancraft.config import EvalConfig
|
6
|
-
from plancraft.environments.actions import (
|
7
|
-
NoOp,
|
8
|
-
StopAction,
|
9
|
-
SymbolicAction,
|
10
|
-
)
|
11
|
-
from plancraft.models.base import ABCModel, History
|
12
|
-
from plancraft.models.bbox_model import IntegratedBoundingBoxModel
|
13
|
-
from plancraft.models.few_shot_images import load_prompt_images
|
14
|
-
from plancraft.models.generators import (
|
15
|
-
OAMGenerator,
|
16
|
-
OpenAIGenerator,
|
17
|
-
TransformersGenerator,
|
18
|
-
)
|
19
|
-
from plancraft.models.prompts import get_prompt_example, get_system_prompt
|
20
|
-
from plancraft.models.utils import (
|
21
|
-
convert_observation_to_message,
|
22
|
-
parse_content_response,
|
23
|
-
)
|
24
|
-
|
25
|
-
|
26
|
-
# Load variables from a `.env` file (python-dotenv) at import time.
# NOTE(review): presumably this supplies credentials (e.g. the OpenAI API key)
# used by the generators below — confirm.
load_dotenv()
|
27
|
-
|
28
|
-
|
29
|
-
class ActModel(ABCModel):
    """
    Model that answers with an action directly, without a thinking step.

    On each ``step`` the observation is recorded and the LLM is queried. Only
    the first line of its reply is parsed as an action. Replies that do not
    parse are fed back into the dialogue and retried, up to
    ``max_invalid_actions`` times, after which ``NoOp()`` is returned.
    """

    def __init__(self, cfg: EvalConfig):
        assert (
            cfg.plancraft.environment.symbolic_action_space
        ), "Real action space unsupported"
        settings = cfg.plancraft

        self.cfg = cfg
        self.env_is_multimodal = not settings.environment.symbolic
        self.use_maskrcnn = settings.use_maskrcnn
        self.use_multimodal_content_format = settings.use_multimodal_content_format
        self.use_text_inventory = settings.use_text_inventory
        self.use_images = settings.use_images

        # Optional Mask R-CNN: a separate model (not the LLM) that is handed
        # to convert_observation_to_message when building observation messages.
        self.bbox_model = None
        if self.use_maskrcnn:
            assert self.env_is_multimodal, "MaskRCNN only supported in multimodal mode"
            self.bbox_model = IntegratedBoundingBoxModel.from_pretrained(
                "gautierdag/plancraft-maskrcnn"
            )
            self.bbox_model.eval()
            if torch.cuda.is_available():
                self.bbox_model.cuda()

        self.few_shot = settings.few_shot
        self.use_system_prompt = settings.system_prompt
        self.max_invalid_actions = 3

        # Choose the LLM backend from the configured model name.
        model_name = settings.model
        if "gpt-4o" in model_name:
            # gpt-4o always gets list-of-parts message content
            self.use_multimodal_content_format = True
            self.llm = OpenAIGenerator(use_images=self.use_images, model_name=model_name)
        elif "oam" in model_name:
            self.llm = OAMGenerator(model_name=model_name)
        else:
            # otherwise the model is transformers based
            self.llm = TransformersGenerator(
                model_name=model_name,
                tokenizer_name=settings.tokenizer,
                quantize=settings.quantize,
                use_hot_cache=settings.hot_cache,
                adapter_name=settings.adapter,
            )

        self.prompt_images = []

        self.valid_actions = settings.valid_actions
        self.system_prompt_text = get_system_prompt(self.valid_actions)

        examples = get_prompt_example(
            self.valid_actions,
            use_text_inventory=self.use_text_inventory,
            use_multimodal_content_format=self.use_multimodal_content_format,
            use_images=self.use_images,
        )
        if self.env_is_multimodal and self.use_images:
            self.prompt_images = load_prompt_images()

        prompt_text = copy.deepcopy(self.system_prompt_text)
        if self.use_multimodal_content_format:
            system_content = [{"text": prompt_text, "type": "text"}]
        else:
            system_content = prompt_text
        self.system_prompt = {"role": "system", "content": system_content}

        if not self.few_shot:
            examples = []
        if not self.use_system_prompt:
            self.system_prompt = None

        self.history = History(
            initial_dialogue=examples,
            use_multimodal_content_format=self.use_multimodal_content_format,
        )

        self.max_messages_window = settings.max_message_window
        self.kv_cache = None

    def reset_history(
        self,
        objective: str,
    ):
        """Begin a new episode: reload few-shot examples (if enabled) and reset the LLM."""
        if self.few_shot:
            fresh_examples = get_prompt_example(
                self.valid_actions,
                use_text_inventory=self.use_text_inventory,
                use_multimodal_content_format=self.use_multimodal_content_format,
                use_images=self.use_images,
            )
        else:
            fresh_examples = []

        self.history.reset(objective=objective, initial_dialogue=fresh_examples)
        self.llm.reset()

    def step(self, observation: dict) -> SymbolicAction | StopAction:
        """
        Return the next action for ``observation``.

        Returns the first reply that parses as an action; after
        ``max_invalid_actions`` unparseable replies, returns ``NoOp()``.
        """
        history = self.history
        history.add_observation_to_history(observation)

        # record the observation as a user message for the LLM
        history.add_message_to_history(
            content=convert_observation_to_message(
                observation,
                objective=history.objective,
                bbox_model=self.bbox_model,
                oam_model="oam" in self.llm.model_name,
                use_text_inventory=self.use_text_inventory,
                use_multimodal_content_format=self.use_multimodal_content_format,
                use_images=self.use_images,
            ),
            role="user",
        )

        for _attempt in range(self.max_invalid_actions):
            message_window, image_window = self.llm.prepare_messages(
                history=history,
                max_messages_window=self.max_messages_window,
                system_prompt=self.system_prompt,
                prompt_images=self.prompt_images,
            )
            replies, tokens = self.llm.generate_unconstrained(
                batch_messages=[message_window],
                images=[image_window],
            )
            history.tokens_used += tokens

            # only the first line of the reply is treated as the action
            reply = replies[0].split("\n")[0].strip()
            history.add_message_to_history(content=reply, role="assistant")

            parsed = parse_content_response(reply, valid_actions=self.valid_actions)
            if isinstance(parsed, str):
                # a str result means no valid action; add it to the dialogue
                # as a user message and try again
                history.add_message_to_history(content=parsed)
                continue

            history.add_action_to_history(parsed)
            return parsed

        # no valid action after max_invalid_actions attempts: useless no-op
        return NoOp()
|
models/base.py
DELETED
@@ -1,152 +0,0 @@
|
|
1
|
-
import abc
|
2
|
-
|
3
|
-
from copy import copy
|
4
|
-
from collections import Counter
|
5
|
-
|
6
|
-
from plancraft.environments.actions import (
|
7
|
-
SymbolicMoveAction,
|
8
|
-
RealActionInteraction,
|
9
|
-
SymbolicSmeltAction,
|
10
|
-
)
|
11
|
-
|
12
|
-
|
13
|
-
class History:
    """
    Per-episode record of dialogue messages, parsed actions, inventory
    snapshots (with per-step item counters), images and tokens used.
    """

    def __init__(
        self,
        objective: str = "",
        initial_dialogue: list[dict] | None = None,
        use_multimodal_content_format=False,
    ):
        # Store a copy. The dialogue is appended to in place, and the previous
        # `= []` default was one list shared (and mutated) by every History
        # created without an explicit initial dialogue.
        self.dialogue_history = (
            list(initial_dialogue) if initial_dialogue is not None else []
        )
        # number of leading (e.g. few-shot) messages that trace() omits
        self.initial_dialogue_length = len(self.dialogue_history)
        self.action_history = []
        self.inventory_history = []
        self.inventory_counters = []
        self.images = []
        self.objective = objective
        self.tokens_used = 0
        self.use_multimodal_content_format = use_multimodal_content_format

    def add_message_to_history(self, content: str | dict, role="user"):
        """
        Append a chat message with the given ``role``.

        ``content`` is either a full message dict (it must contain a
        "content" key; its "role" is overwritten in place) or a plain string.
        In multimodal content format, strings are wrapped as
        ``[{"type": "text", "text": ...}]``. Assistant messages are printed
        exactly once.
        """
        if role == "assistant":
            print(content)

        if isinstance(content, dict):
            assert "content" in content, "content key not found in message"
            content["role"] = role
            self.dialogue_history.append(content)
        elif self.use_multimodal_content_format:
            # List-of-parts content. Appended directly (the old code recursed,
            # which printed assistant messages a second time).
            self.dialogue_history.append(
                {"content": [{"type": "text", "text": content}], "role": role}
            )
        else:
            self.dialogue_history.append({"role": role, "content": content})

    def add_action_to_history(
        self, action: SymbolicSmeltAction | RealActionInteraction | SymbolicMoveAction
    ):
        """Record ``action.model_dump()``; ``None`` is ignored."""
        if action is None:
            return
        self.action_history.append(action.model_dump())

    def add_inventory_to_history(self, inventory: list[dict[str, int]]):
        """Store an inventory snapshot and a Counter of item type -> quantity."""
        self.inventory_history.append(inventory)

        # count inventory
        counter = Counter()
        for item in inventory:
            # ignore slot 0 (whether keyed as "slot" or "index")
            if "slot" in item and item["slot"] == 0:
                continue
            if "index" in item and item["index"] == 0:
                continue
            counter[item["type"]] += item["quantity"]

        self.inventory_counters.append(counter)

    def add_image_to_history(self, image):
        """Append an observation image."""
        self.images.append(image)

    def add_observation_to_history(self, observation: dict):
        """Record the inventory (non-empty slots only) and "pov" image of an observation."""
        if observation is None:
            return
        if "inventory" in observation:
            clean_inv = []
            # remove empty slots
            for item in observation["inventory"]:
                if item["quantity"] > 0:
                    clean_inv.append(item)
            self.add_inventory_to_history(clean_inv)
        if "pov" in observation:
            self.add_image_to_history(observation["pov"])

    def __str__(self):
        return str(self.dialogue_history)

    def reset(self, objective: str = "", initial_dialogue: list[dict] | None = None):
        """
        Clear all per-episode state and start from ``initial_dialogue``.

        ``initial_dialogue_length`` is updated too, so ``trace()`` stays
        correct when the new initial dialogue has a different length.
        """
        self.dialogue_history = (
            list(initial_dialogue) if initial_dialogue is not None else []
        )
        self.initial_dialogue_length = len(self.dialogue_history)
        self.action_history = []
        self.inventory_history = []
        self.inventory_counters = []
        self.images = []
        self.objective = objective
        # NOTE(review): tokens_used is not reset here, so it accumulates
        # across episodes — confirm whether that is intended.

    def set_objective(self, objective: str):
        self.objective = objective

    def trace(self):
        """Return shallow copies of the episode's state, excluding the initial dialogue."""
        return {
            "dialogue_history": copy(
                self.dialogue_history[self.initial_dialogue_length :]
            ),
            "action_history": copy(self.action_history),
            "inventory_history": copy(self.inventory_history),
            "objective": copy(self.objective),
            "tokens_used": copy(self.tokens_used),
        }

    @property
    def num_steps(self):
        return len(self.action_history)

    def check_stuck(self, max_steps_no_change: int = 10) -> bool:
        """
        If inventory content does not change for max_steps_no_change steps
        the agent is considered stuck.

        With N=10, the oracle solver can still solve 100% of the examples
        """
        if len(self.inventory_counters) <= max_steps_no_change:
            return False

        return all(
            c == self.inventory_counters[-max_steps_no_change - 1]
            for c in self.inventory_counters[-max_steps_no_change - 1 :]
        )
|
132
|
-
|
133
|
-
|
134
|
-
class ABCModel(abc.ABC):
    """
    Interface a model has to implement to be driven by the evaluator.

    The default ``reset_history`` assumes the subclass sets ``self.history``.
    """

    @abc.abstractmethod
    def step(
        self, observation: list[dict]
    ) -> list[SymbolicMoveAction | RealActionInteraction | SymbolicSmeltAction]:
        """
        Produce a valid action (of the three available kinds) for the
        observation.

        NOTE(review): the annotation and the original wording describe a
        batch call (list of observations in, list of actions out), but
        ActModel.step takes a single observation dict and returns a single
        action — confirm which contract the evaluator relies on.
        """
        raise NotImplementedError

    def reset_history(self, objective: str = ""):
        """Reset ``self.history`` with a new objective."""
        episode_history = self.history
        episode_history.reset(objective=objective)
|