plancraft 0.3.10__py3-none-any.whl → 0.3.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plancraft/evaluator.py +61 -52
- plancraft/models/dummy.py +7 -2
- plancraft/utils.py +0 -16
- {plancraft-0.3.10.dist-info → plancraft-0.3.12.dist-info}/METADATA +1 -1
- {plancraft-0.3.10.dist-info → plancraft-0.3.12.dist-info}/RECORD +7 -7
- {plancraft-0.3.10.dist-info → plancraft-0.3.12.dist-info}/WHEEL +0 -0
- {plancraft-0.3.10.dist-info → plancraft-0.3.12.dist-info}/licenses/LICENSE +0 -0
plancraft/evaluator.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
import json
|
2
2
|
import os
|
3
3
|
from typing import Optional
|
4
|
+
from copy import deepcopy
|
4
5
|
|
5
6
|
import imageio
|
6
7
|
from loguru import logger
|
@@ -38,7 +39,6 @@ class Evaluator:
|
|
38
39
|
def __init__(
|
39
40
|
self,
|
40
41
|
run_name: str,
|
41
|
-
model: PlancraftBaseModel,
|
42
42
|
actions: list[ActionHandlerBase] = [MoveActionHandler(), SmeltActionHandler()],
|
43
43
|
output_dir: str = "output",
|
44
44
|
split: str = "val.small",
|
@@ -61,6 +61,13 @@ class Evaluator:
|
|
61
61
|
self.use_fasterrcnn = use_fasterrcnn
|
62
62
|
self.max_steps = max_steps
|
63
63
|
self.resume = resume
|
64
|
+
self.resolution = resolution
|
65
|
+
|
66
|
+
# history args
|
67
|
+
self.system_prompt = system_prompt
|
68
|
+
self.prompt_examples = prompt_examples
|
69
|
+
self.prompt_images = prompt_images
|
70
|
+
self.few_shot = few_shot
|
64
71
|
|
65
72
|
self.output_dir = f"{output_dir}/{run_name}/{split}"
|
66
73
|
self.generation_number = 0
|
@@ -69,28 +76,6 @@ class Evaluator:
|
|
69
76
|
# load all examples
|
70
77
|
self.examples: list[PlancraftExample] = self.load_dataset(split)
|
71
78
|
|
72
|
-
# start environment
|
73
|
-
self.environment = PlancraftEnvironment(
|
74
|
-
inventory=[],
|
75
|
-
resolution=resolution,
|
76
|
-
)
|
77
|
-
|
78
|
-
# initialise history/dialogue tracking
|
79
|
-
self.history = History(
|
80
|
-
actions=actions,
|
81
|
-
use_multimodal_content_format=use_multimodal_content_format,
|
82
|
-
use_images=use_images,
|
83
|
-
use_text_inventory=use_text_inventory,
|
84
|
-
resolution=resolution,
|
85
|
-
few_shot=few_shot,
|
86
|
-
system_prompt=system_prompt,
|
87
|
-
prompt_examples=prompt_examples,
|
88
|
-
prompt_images=prompt_images,
|
89
|
-
)
|
90
|
-
|
91
|
-
# load model
|
92
|
-
self.model = model
|
93
|
-
|
94
79
|
def save_results_dict(self, example: PlancraftExample, results_dict: dict):
|
95
80
|
output_dir = f"{self.output_dir}/{self.generation_number}"
|
96
81
|
os.makedirs(output_dir, exist_ok=True)
|
@@ -124,14 +109,6 @@ class Evaluator:
|
|
124
109
|
dataset = json.load(f)
|
125
110
|
return [PlancraftExample(**example) for example in dataset]
|
126
111
|
|
127
|
-
def reset(
|
128
|
-
self,
|
129
|
-
example: PlancraftExample,
|
130
|
-
):
|
131
|
-
self.environment.reset(new_inventory=example.slotted_inventory)
|
132
|
-
self.model.reset()
|
133
|
-
self.history.reset()
|
134
|
-
|
135
112
|
def check_done(self, inventory: dict, target: str):
|
136
113
|
"""
|
137
114
|
Check that target object is obtained
|
@@ -142,14 +119,16 @@ class Evaluator:
|
|
142
119
|
return True
|
143
120
|
return False
|
144
121
|
|
145
|
-
def parse_raw_model_response(
|
122
|
+
def parse_raw_model_response(
|
123
|
+
self, generated_text: str, observation=None, history=None
|
124
|
+
) -> str:
|
146
125
|
"""
|
147
126
|
Given a message and set of action handlers, parse the content to return the action
|
148
127
|
or a message if the action is not valid/requires message response
|
149
128
|
"""
|
150
129
|
for handler in self.actions:
|
151
130
|
match_output = handler.match(
|
152
|
-
generated_text, observation=observation, history=
|
131
|
+
generated_text, observation=observation, history=history
|
153
132
|
)
|
154
133
|
if match_output:
|
155
134
|
return match_output
|
@@ -159,6 +138,7 @@ class Evaluator:
|
|
159
138
|
def convert_observation_to_message(
|
160
139
|
self,
|
161
140
|
observation: dict,
|
141
|
+
model: PlancraftBaseModel = None,
|
162
142
|
) -> str | dict:
|
163
143
|
"""
|
164
144
|
Convert an environment observation to the message format used by an LLM chat model
|
@@ -170,8 +150,9 @@ class Evaluator:
|
|
170
150
|
- use_images: bool - Whether to append an image to the message content - must be used with use_multimodal_content_format.
|
171
151
|
"""
|
172
152
|
if self.use_fasterrcnn:
|
153
|
+
assert model is not None, "Model must be provided to convert image to text"
|
173
154
|
# convert image to inventory using fasterrcnn
|
174
|
-
inventory =
|
155
|
+
inventory = model.bbox_model.get_inventory(observation["image"].copy())
|
175
156
|
text_content = target_and_inventory_to_text_obs(
|
176
157
|
observation["target"], inventory
|
177
158
|
)
|
@@ -190,15 +171,38 @@ class Evaluator:
|
|
190
171
|
content_list.append({"type": "image"})
|
191
172
|
return {"content": content_list}
|
192
173
|
|
193
|
-
def eval_example(
|
174
|
+
def eval_example(
|
175
|
+
self,
|
176
|
+
example: PlancraftExample,
|
177
|
+
model: PlancraftBaseModel,
|
178
|
+
) -> dict:
|
194
179
|
"""Given the loaded model and an example from Plancraft
|
195
180
|
run the episode until success or termination."""
|
181
|
+
|
182
|
+
# start environment
|
183
|
+
environment = PlancraftEnvironment(
|
184
|
+
inventory=example.slotted_inventory,
|
185
|
+
resolution=self.resolution,
|
186
|
+
)
|
187
|
+
|
188
|
+
# initialise history/dialogue tracking
|
189
|
+
history = History(
|
190
|
+
actions=self.actions,
|
191
|
+
use_multimodal_content_format=self.use_multimodal_content_format,
|
192
|
+
use_images=self.use_images,
|
193
|
+
use_text_inventory=self.use_text_inventory,
|
194
|
+
resolution=self.resolution,
|
195
|
+
few_shot=self.few_shot,
|
196
|
+
system_prompt=deepcopy(self.system_prompt),
|
197
|
+
prompt_examples=deepcopy(self.prompt_examples),
|
198
|
+
prompt_images=deepcopy(self.prompt_images),
|
199
|
+
)
|
200
|
+
|
196
201
|
success = False
|
197
|
-
self.reset(example)
|
198
202
|
action = None
|
199
203
|
|
200
204
|
# run episode until stuck or until max steps is reached
|
201
|
-
while
|
205
|
+
while history.num_steps < self.max_steps:
|
202
206
|
# if the action is stop then we end the episode
|
203
207
|
if isinstance(action, StopAction):
|
204
208
|
# if the action is stop and task is impossible then success
|
@@ -207,16 +211,16 @@ class Evaluator:
|
|
207
211
|
break
|
208
212
|
# action is external tool then it is str
|
209
213
|
if isinstance(action, str):
|
210
|
-
observation =
|
214
|
+
observation = environment.step()
|
211
215
|
observation["target"] = example.target
|
212
216
|
observation["message"] = action
|
213
217
|
# action is environment action
|
214
218
|
else:
|
215
|
-
observation =
|
219
|
+
observation = environment.step(action)
|
216
220
|
# convert inventory observation to text message
|
217
221
|
observation["target"] = example.target
|
218
222
|
observation["message"] = self.convert_observation_to_message(
|
219
|
-
observation
|
223
|
+
observation, model=model
|
220
224
|
)
|
221
225
|
# check if the episode is done
|
222
226
|
success = self.check_done(observation["inventory"], example.target)
|
@@ -225,29 +229,30 @@ class Evaluator:
|
|
225
229
|
break
|
226
230
|
|
227
231
|
# add observation to history
|
228
|
-
|
232
|
+
history.add_observation_to_history(observation)
|
229
233
|
# add observation message to history
|
230
|
-
|
231
|
-
content=observation["message"], role="user"
|
232
|
-
)
|
234
|
+
history.add_message_to_history(content=observation["message"], role="user")
|
233
235
|
# predict next action
|
234
|
-
raw_action =
|
236
|
+
raw_action = model.step(observation, dialogue_history=history)
|
235
237
|
# add message to history
|
236
|
-
|
238
|
+
history.add_message_to_history(content=raw_action, role="assistant")
|
237
239
|
# parse the raw action
|
238
|
-
action = self.parse_raw_model_response(
|
240
|
+
action = self.parse_raw_model_response(
|
241
|
+
raw_action, observation=observation, history=history
|
242
|
+
)
|
239
243
|
|
240
244
|
# save results and reset
|
241
245
|
return {
|
242
246
|
"success": success,
|
243
247
|
"recipe_type": example.recipe_type,
|
244
248
|
"complexity": example.complexity_split,
|
245
|
-
"number_of_steps":
|
246
|
-
"model_trace":
|
249
|
+
"number_of_steps": history.num_steps,
|
250
|
+
"model_trace": history.trace(),
|
247
251
|
"example_id": example.id,
|
252
|
+
"images": history.images,
|
248
253
|
}
|
249
254
|
|
250
|
-
def eval_all_examples(self, progress_bar=False) -> list:
|
255
|
+
def eval_all_examples(self, model, progress_bar=False) -> list:
|
251
256
|
results = []
|
252
257
|
pbar = tqdm(
|
253
258
|
total=len(self.examples),
|
@@ -268,10 +273,14 @@ class Evaluator:
|
|
268
273
|
]:
|
269
274
|
continue
|
270
275
|
|
271
|
-
result = self.eval_example(example)
|
276
|
+
result = self.eval_example(example, model=model)
|
277
|
+
model.reset()
|
278
|
+
|
279
|
+
# save images and results
|
280
|
+
self.save_images(example, result["images"])
|
281
|
+
del result["images"]
|
272
282
|
results.append(result)
|
273
283
|
self.save_results_dict(example, result)
|
274
|
-
self.save_images(example, self.history.images)
|
275
284
|
|
276
285
|
correct += int(result["success"])
|
277
286
|
count += 1
|
plancraft/models/dummy.py
CHANGED
@@ -18,14 +18,19 @@ class DummyModel(PlancraftBaseModel):
|
|
18
18
|
pass
|
19
19
|
|
20
20
|
def random_select(self, observation):
|
21
|
-
# randomly pick an item from the inventory
|
21
|
+
# randomly pick an item that has quantity 1 from the inventory
|
22
22
|
item_indices = set()
|
23
23
|
for slot, item in observation["inventory"].items():
|
24
|
-
if item["quantity"]
|
24
|
+
if item["quantity"] == 1:
|
25
25
|
item_indices.add(slot)
|
26
26
|
all_slots_to = set(range(1, 46))
|
27
27
|
empty_slots = all_slots_to - item_indices
|
28
28
|
|
29
|
+
# if no item has quantity == 1, randomly pick any item
|
30
|
+
if len(item_indices) == 0:
|
31
|
+
item_indices = set(observation["inventory"].keys())
|
32
|
+
|
33
|
+
# move the item to a random empty slot
|
29
34
|
random_slot_from = random.choice(list(item_indices))
|
30
35
|
random_slot_to = random.choice(list(empty_slots))
|
31
36
|
|
plancraft/utils.py
CHANGED
@@ -44,8 +44,6 @@ class History:
|
|
44
44
|
self.resolution = resolution # low, medium, high
|
45
45
|
|
46
46
|
self.inventory_history = []
|
47
|
-
self.inventory_counters = []
|
48
|
-
|
49
47
|
self.tokens_used = 0
|
50
48
|
|
51
49
|
# use system prompt if provided
|
@@ -105,14 +103,6 @@ class History:
|
|
105
103
|
|
106
104
|
def add_inventory_to_history(self, inventory: dict):
|
107
105
|
self.inventory_history.append(inventory)
|
108
|
-
# count inventory
|
109
|
-
counter = Counter()
|
110
|
-
for slot, item in inventory.items():
|
111
|
-
# ignore slot 0
|
112
|
-
if slot == 0:
|
113
|
-
continue
|
114
|
-
counter[item["type"]] += item["quantity"]
|
115
|
-
self.inventory_counters.append(counter)
|
116
106
|
|
117
107
|
def add_image_to_history(self, image):
|
118
108
|
self.images.append(image)
|
@@ -121,11 +111,6 @@ class History:
|
|
121
111
|
if observation is None:
|
122
112
|
return
|
123
113
|
if "inventory" in observation:
|
124
|
-
# clean_inv = []
|
125
|
-
# remove empty slots
|
126
|
-
# for slot, item in observation["inventory"].items():
|
127
|
-
# if item["quantity"] > 0:
|
128
|
-
# clean_inv.append(item)
|
129
114
|
self.add_inventory_to_history(observation["inventory"])
|
130
115
|
if "image" in observation:
|
131
116
|
self.add_image_to_history(observation["image"])
|
@@ -140,7 +125,6 @@ class History:
|
|
140
125
|
self.initial_dialogue_length = len(self.dialogue_history)
|
141
126
|
|
142
127
|
self.inventory_history = []
|
143
|
-
self.inventory_counters = []
|
144
128
|
|
145
129
|
self.tokens_used = 0
|
146
130
|
|
@@ -1,8 +1,8 @@
|
|
1
1
|
plancraft/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
plancraft/config.py,sha256=ShsFRlJ7plsl3ToD9fiO_4LDQuXdbjNV6Xp6o3Yk2Yg,4315
|
3
|
-
plancraft/evaluator.py,sha256=
|
3
|
+
plancraft/evaluator.py,sha256=OZ9-xRiCfMPYIVMHZj8UMU53HGWinp18Ilj5BNySioI,11119
|
4
4
|
plancraft/generate_dataset.py,sha256=DlrU-PmvWqSNJD1g1-8Lpb8n3N-Ogw3rje1nrRzjGKs,2382
|
5
|
-
plancraft/utils.py,sha256=
|
5
|
+
plancraft/utils.py,sha256=0Uq-3VE-bTRstalzKknBJ-ExWf8ec_Jrg4QNEk8bJ-o,5778
|
6
6
|
plancraft/data/test.json,sha256=4jWfYMAVuZCFmGB4iZJAjlh9_8jXECdaGp8xn7_tAM4,1317131
|
7
7
|
plancraft/data/test.small.easy.json,sha256=5NZEJ2PqIgmHQecJOIVQyM1D6GFKyJq7GVmgRudaqQk,189304
|
8
8
|
plancraft/data/test.small.json,sha256=eULAG1rdolRMXPrecV-7YoDIheKGyIT5MVpWdISV0wg,270089
|
@@ -1915,12 +1915,12 @@ plancraft/models/__init__.py,sha256=TBrarn93qt4IFJRNqtzOfaA8jGMPCgD7DFs-M84ipmk,
|
|
1915
1915
|
plancraft/models/act.py,sha256=6Xb8rylg3OngOraVFgduH_hQR62VcoyTeFntN4q3hsQ,2691
|
1916
1916
|
plancraft/models/base.py,sha256=uhG1tRmsBerJzW8qHoLyLEYpveDv0co7AAhi4mSfyO4,661
|
1917
1917
|
plancraft/models/bbox_model.py,sha256=3b1IEspoHiVUR6GOWjEbp4YoxRhGkzKt-eOiwaN8NXo,17091
|
1918
|
-
plancraft/models/dummy.py,sha256=
|
1918
|
+
plancraft/models/dummy.py,sha256=856oEX6NquXSIIfQLTEFFeB8ib7VUUs5cB0TVHAiFvI,1248
|
1919
1919
|
plancraft/models/generators.py,sha256=F76_iPiqxUjDIrQwF58tzM0bLM91OkZJ0sBqBuki5wY,13939
|
1920
1920
|
plancraft/models/oracle.py,sha256=jDCE6zVFvbwFpDzQZTkHIlRwMud1yMJ4LVIdfpt5ddU,8449
|
1921
1921
|
plancraft/models/utils.py,sha256=E-sZohvolWgGbpHQKgAgkgIfUJoVnT5pMt6JP8xLHKg,4034
|
1922
1922
|
plancraft/train/dataset.py,sha256=oFqEd4LG9oEQ-71teh0Wf7-jJbtybT2ZibfM2bBdBkM,5474
|
1923
|
-
plancraft-0.3.
|
1924
|
-
plancraft-0.3.
|
1925
|
-
plancraft-0.3.
|
1926
|
-
plancraft-0.3.
|
1923
|
+
plancraft-0.3.12.dist-info/METADATA,sha256=RJyvF0PV84-_p-7ijstEcYrwjKLqedRce_LL5zF5ihs,11148
|
1924
|
+
plancraft-0.3.12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
1925
|
+
plancraft-0.3.12.dist-info/licenses/LICENSE,sha256=YGR8ehDB4t-T-lOQKMfKNR-2zsOU7E3E5NA8t25HKE0,1070
|
1926
|
+
plancraft-0.3.12.dist-info/RECORD,,
|
File without changes
|
File without changes
|