plancraft 0.3.16__py3-none-any.whl → 0.3.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plancraft/environment/actions.py +34 -19
- plancraft/environment/search.py +11 -5
- plancraft/evaluator.py +134 -3
- plancraft/models/dummy.py +3 -0
- plancraft/models/oracle.py +9 -0
- {plancraft-0.3.16.dist-info → plancraft-0.3.18.dist-info}/METADATA +1 -1
- {plancraft-0.3.16.dist-info → plancraft-0.3.18.dist-info}/RECORD +9 -9
- {plancraft-0.3.16.dist-info → plancraft-0.3.18.dist-info}/WHEEL +0 -0
- {plancraft-0.3.16.dist-info → plancraft-0.3.18.dist-info}/licenses/LICENSE +0 -0
plancraft/environment/actions.py
CHANGED
@@ -208,10 +208,10 @@ class MoveActionHandler(ActionHandlerBase):
|
|
208
208
|
"""
|
209
209
|
Parse the raw model response to a MoveAction
|
210
210
|
"""
|
211
|
-
action_match = re.search(f"({self.action_name}):", generated_text)
|
212
|
-
if not action_match:
|
213
|
-
return
|
214
211
|
try:
|
212
|
+
action_match = re.search(f"({self.action_name}):", generated_text)
|
213
|
+
if not action_match:
|
214
|
+
return
|
215
215
|
slot_from = re.search(r" from (\[[ABCI]?\d+\])", generated_text).group(1)
|
216
216
|
slot_to = re.search(r" to (\[[ABCI]?\d+\])", generated_text).group(1)
|
217
217
|
quantity = re.search(r"with quantity (\d+)", generated_text).group(1)
|
@@ -221,8 +221,10 @@ class MoveActionHandler(ActionHandlerBase):
|
|
221
221
|
quantity=quantity,
|
222
222
|
)
|
223
223
|
return action
|
224
|
-
except AttributeError
|
225
|
-
return
|
224
|
+
except AttributeError:
|
225
|
+
return (
|
226
|
+
f"Format Error. Action be in the format: {self.prompt_format_example}"
|
227
|
+
)
|
226
228
|
|
227
229
|
|
228
230
|
class SmeltActionHandler(ActionHandlerBase):
|
@@ -242,10 +244,11 @@ class SmeltActionHandler(ActionHandlerBase):
|
|
242
244
|
"""
|
243
245
|
Parse the raw model response to a SmeltAction
|
244
246
|
"""
|
245
|
-
|
246
|
-
if not action_match:
|
247
|
-
return
|
247
|
+
|
248
248
|
try:
|
249
|
+
action_match = re.search(f"({self.action_name}):", generated_text)
|
250
|
+
if not action_match:
|
251
|
+
return
|
249
252
|
slot_from = re.search(r" from (\[[ABCI]?\d+\])", generated_text).group(1)
|
250
253
|
slot_to = re.search(r" to (\[[ABCI]?\d+\])", generated_text).group(1)
|
251
254
|
quantity = re.search(r"with quantity (\d+)", generated_text).group(1)
|
@@ -255,8 +258,10 @@ class SmeltActionHandler(ActionHandlerBase):
|
|
255
258
|
quantity=quantity,
|
256
259
|
)
|
257
260
|
return action
|
258
|
-
except AttributeError
|
259
|
-
return
|
261
|
+
except AttributeError:
|
262
|
+
return (
|
263
|
+
f"Format Error. Action be in the format: {self.prompt_format_example}"
|
264
|
+
)
|
260
265
|
|
261
266
|
|
262
267
|
class ImpossibleActionHandler(ActionHandlerBase):
|
@@ -276,11 +281,16 @@ class ImpossibleActionHandler(ActionHandlerBase):
|
|
276
281
|
"""
|
277
282
|
Parse the raw model response to a StopAction
|
278
283
|
"""
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
+
try:
|
285
|
+
action_match = re.search(f"({self.action_name}):", generated_text)
|
286
|
+
if not action_match:
|
287
|
+
return
|
288
|
+
reason = re.search(r"impossible: (.*)", generated_text).group(1)
|
289
|
+
return StopAction(reason=reason)
|
290
|
+
except AttributeError:
|
291
|
+
return (
|
292
|
+
f"Format Error. Action be in the format: {self.prompt_format_example}"
|
293
|
+
)
|
284
294
|
|
285
295
|
|
286
296
|
class ThinkActionHandler(ActionHandlerBase):
|
@@ -300,7 +310,12 @@ class ThinkActionHandler(ActionHandlerBase):
|
|
300
310
|
"""
|
301
311
|
Parse the raw model response to a ThinkAction
|
302
312
|
"""
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
313
|
+
try:
|
314
|
+
action_match = re.search(f"({self.action_name}):", generated_text)
|
315
|
+
if not action_match:
|
316
|
+
return
|
317
|
+
return "Ok"
|
318
|
+
except AttributeError:
|
319
|
+
return (
|
320
|
+
f"Format Error. Action be in the format: {self.prompt_format_example}"
|
321
|
+
)
|
plancraft/environment/search.py
CHANGED
@@ -46,8 +46,14 @@ class GoldSearchActionHandler(ActionHandlerBase):
|
|
46
46
|
"""
|
47
47
|
Parse the raw model response to a SearchAction
|
48
48
|
"""
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
49
|
+
try:
|
50
|
+
action_match = re.search(f"({self.action_name}):", generated_text)
|
51
|
+
if not action_match:
|
52
|
+
return
|
53
|
+
|
54
|
+
search_target = re.search(r"search: *(\w+)", generated_text).group(1)
|
55
|
+
return gold_search_recipe(search_target)
|
56
|
+
except AttributeError:
|
57
|
+
return (
|
58
|
+
f"Format Error. Action be in the format: {self.prompt_format_example}"
|
59
|
+
)
|
plancraft/evaluator.py
CHANGED
@@ -176,12 +176,14 @@ class Evaluator:
|
|
176
176
|
example: PlancraftExample,
|
177
177
|
model: PlancraftBaseModel,
|
178
178
|
) -> dict:
|
179
|
-
"""
|
180
|
-
|
179
|
+
"""
|
180
|
+
Given the loaded model and an example from Plancraft
|
181
|
+
run the episode until success or termination.
|
182
|
+
"""
|
181
183
|
|
182
184
|
# start environment
|
183
185
|
environment = PlancraftEnvironment(
|
184
|
-
inventory=example.slotted_inventory,
|
186
|
+
inventory=deepcopy(example.slotted_inventory),
|
185
187
|
resolution=self.resolution,
|
186
188
|
)
|
187
189
|
|
@@ -252,6 +254,135 @@ class Evaluator:
|
|
252
254
|
"images": history.images,
|
253
255
|
}
|
254
256
|
|
257
|
+
def batch_eval_examples(
|
258
|
+
self,
|
259
|
+
examples: list[PlancraftExample],
|
260
|
+
model,
|
261
|
+
) -> list:
|
262
|
+
# Initialize environments and histories
|
263
|
+
environments = [
|
264
|
+
PlancraftEnvironment(
|
265
|
+
inventory=deepcopy(examples[i].slotted_inventory),
|
266
|
+
resolution=self.resolution,
|
267
|
+
)
|
268
|
+
for i in range(len(examples))
|
269
|
+
]
|
270
|
+
|
271
|
+
histories = [
|
272
|
+
History(
|
273
|
+
actions=self.actions,
|
274
|
+
use_multimodal_content_format=self.use_multimodal_content_format,
|
275
|
+
use_images=self.use_images,
|
276
|
+
use_text_inventory=self.use_text_inventory,
|
277
|
+
resolution=self.resolution,
|
278
|
+
few_shot=self.few_shot,
|
279
|
+
system_prompt=deepcopy(self.system_prompt),
|
280
|
+
prompt_examples=deepcopy(self.prompt_examples),
|
281
|
+
prompt_images=deepcopy(self.prompt_images),
|
282
|
+
)
|
283
|
+
for _ in range(len(examples))
|
284
|
+
]
|
285
|
+
|
286
|
+
# Track which environments are still active
|
287
|
+
active_mask = [True for _ in range(len(examples))]
|
288
|
+
results = [None for _ in range(len(examples))]
|
289
|
+
steps_taken = [0 for _ in range(len(examples))]
|
290
|
+
actions = [None for _ in range(len(examples))]
|
291
|
+
|
292
|
+
while any(active_mask) and all(steps < self.max_steps for steps in steps_taken):
|
293
|
+
# Get observations for all active environments
|
294
|
+
observations = []
|
295
|
+
active_indices = []
|
296
|
+
for i, (env, action, active) in enumerate(
|
297
|
+
zip(environments, actions, active_mask)
|
298
|
+
):
|
299
|
+
if not active:
|
300
|
+
continue
|
301
|
+
|
302
|
+
if isinstance(action, StopAction):
|
303
|
+
# Handle stop action
|
304
|
+
active_mask[i] = False
|
305
|
+
results[i] = {
|
306
|
+
"success": examples[i].impossible,
|
307
|
+
"recipe_type": examples[i].recipe_type,
|
308
|
+
"complexity": examples[i].complexity_split,
|
309
|
+
"number_of_steps": steps_taken[i],
|
310
|
+
"model_trace": histories[i].trace(),
|
311
|
+
"example_id": examples[i].id,
|
312
|
+
"images": histories[i].images,
|
313
|
+
}
|
314
|
+
logger.info("STOP")
|
315
|
+
continue
|
316
|
+
|
317
|
+
active_indices.append(i)
|
318
|
+
if isinstance(action, str):
|
319
|
+
# Handle message action
|
320
|
+
obs = env.step()
|
321
|
+
obs["target"] = examples[i].target
|
322
|
+
obs["message"] = action
|
323
|
+
else:
|
324
|
+
# Handle environment action
|
325
|
+
obs = env.step(action)
|
326
|
+
obs["target"] = examples[i].target
|
327
|
+
obs["message"] = self.convert_observation_to_message(
|
328
|
+
obs, model=model
|
329
|
+
)
|
330
|
+
|
331
|
+
# Check if done
|
332
|
+
if self.check_done(obs["inventory"], examples[i].target):
|
333
|
+
active_mask[i] = False
|
334
|
+
results[i] = {
|
335
|
+
"success": True,
|
336
|
+
"recipe_type": examples[i].recipe_type,
|
337
|
+
"complexity": examples[i].complexity_split,
|
338
|
+
"number_of_steps": steps_taken[i],
|
339
|
+
"model_trace": histories[i].trace(),
|
340
|
+
"example_id": examples[i].id,
|
341
|
+
"images": histories[i].images,
|
342
|
+
}
|
343
|
+
continue
|
344
|
+
|
345
|
+
observations.append(obs)
|
346
|
+
histories[i].add_observation_to_history(obs)
|
347
|
+
histories[i].add_message_to_history(content=obs["message"], role="user")
|
348
|
+
steps_taken[i] += 1
|
349
|
+
|
350
|
+
if not observations:
|
351
|
+
break
|
352
|
+
|
353
|
+
# Batch predict actions for active environments
|
354
|
+
active_histories = [histories[i] for i in active_indices]
|
355
|
+
raw_actions = model.batch_step(
|
356
|
+
observations, dialogue_histories=active_histories
|
357
|
+
)
|
358
|
+
|
359
|
+
# Process actions for each active environment
|
360
|
+
for idx, raw_action in zip(active_indices, raw_actions):
|
361
|
+
logger.info(f"{histories[idx].num_steps}, {raw_action}")
|
362
|
+
histories[idx].add_message_to_history(
|
363
|
+
content=raw_action, role="assistant"
|
364
|
+
)
|
365
|
+
actions[idx] = self.parse_raw_model_response(
|
366
|
+
raw_action,
|
367
|
+
observation=observations[active_indices.index(idx)],
|
368
|
+
history=histories[idx],
|
369
|
+
)
|
370
|
+
|
371
|
+
# Fill in results for environments that didn't finish
|
372
|
+
for i, result in enumerate(results):
|
373
|
+
if result is None:
|
374
|
+
results[i] = {
|
375
|
+
"success": False,
|
376
|
+
"recipe_type": examples[i].recipe_type,
|
377
|
+
"complexity": examples[i].complexity_split,
|
378
|
+
"number_of_steps": steps_taken[i],
|
379
|
+
"model_trace": histories[i].trace(),
|
380
|
+
"example_id": examples[i].id,
|
381
|
+
"images": histories[i].images,
|
382
|
+
}
|
383
|
+
|
384
|
+
return results
|
385
|
+
|
255
386
|
def eval_all_examples(self, model, progress_bar=False) -> list:
|
256
387
|
results = []
|
257
388
|
pbar = tqdm(
|
plancraft/models/dummy.py
CHANGED
@@ -40,3 +40,6 @@ class DummyModel(PlancraftBaseModel):
|
|
40
40
|
|
41
41
|
def step(self, observation: dict, **kwargs) -> str:
|
42
42
|
return str(self.random_select(observation))
|
43
|
+
|
44
|
+
def batch_step(self, observations: list[dict], **kwargs) -> list:
|
45
|
+
return [self.step(observation) for observation in observations]
|
plancraft/models/oracle.py
CHANGED
@@ -38,3 +38,12 @@ class OracleModel(PlancraftBaseModel):
|
|
38
38
|
|
39
39
|
action = self.subplans.pop(0)
|
40
40
|
return action
|
41
|
+
|
42
|
+
def batch_step(self, observations: list[dict], **kwargs) -> list:
|
43
|
+
# Need to fully isolate state between examples
|
44
|
+
actions = []
|
45
|
+
for observation in observations:
|
46
|
+
self.reset()
|
47
|
+
action = self.step(observation)
|
48
|
+
actions.append(action)
|
49
|
+
return actions
|
@@ -1,6 +1,6 @@
|
|
1
1
|
plancraft/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
plancraft/config.py,sha256=ShsFRlJ7plsl3ToD9fiO_4LDQuXdbjNV6Xp6o3Yk2Yg,4315
|
3
|
-
plancraft/evaluator.py,sha256=
|
3
|
+
plancraft/evaluator.py,sha256=v8itX8buduqTZdR39gtLwdhKGEnSX3rJv9Yd13EzNgQ,16395
|
4
4
|
plancraft/generate_dataset.py,sha256=DlrU-PmvWqSNJD1g1-8Lpb8n3N-Ogw3rje1nrRzjGKs,2382
|
5
5
|
plancraft/utils.py,sha256=67UUDMSv8TqX_I0fL5-yG_vkHvTZlnhSLkktWAg5p34,5712
|
6
6
|
plancraft/data/test.json,sha256=4jWfYMAVuZCFmGB4iZJAjlh9_8jXECdaGp8xn7_tAM4,1317131
|
@@ -11,14 +11,14 @@ plancraft/data/val.json,sha256=IToAiaqUNQi_xhX1bzmInuskLaT7C2ryQjP-CZkzL24,13044
|
|
11
11
|
plancraft/data/val.small.easy.json,sha256=9zEmqepjXG2NIp88xnFqOCkwsUsku3HEwHoQGxgTr6U,190252
|
12
12
|
plancraft/data/val.small.json,sha256=76E9EFaljDQyAokg97e-IblvcOe6KbrdKkXvRxhhkgo,237653
|
13
13
|
plancraft/environment/__init__.py,sha256=XFsFny4lH195AwAmL-WeCaF9ZCMgc7IgXIwhQ8FTdgE,505
|
14
|
-
plancraft/environment/actions.py,sha256=
|
14
|
+
plancraft/environment/actions.py,sha256=VhPSRr0b1ySxb106TcBFdb3MdycxWwQGzqDnWQagm-8,10007
|
15
15
|
plancraft/environment/env.py,sha256=A4532st7JFBYBF_Nh0CEEi3ZTLJAeaB3t9PAIVSemj0,16390
|
16
16
|
plancraft/environment/items.py,sha256=Z9rhSyVDEoHF1pxRvhyiT94tyQJaWHi3wUHVcamz82o,221
|
17
17
|
plancraft/environment/planner.py,sha256=uIOJjIoyT_4pxeWeTKb8BkLJyKZG0-AMoEOkZs6Ua9A,19340
|
18
18
|
plancraft/environment/prompts.py,sha256=8QXclX0ygpL02uZichE1AVkbdn_0HGteD5bzo0FZGOU,6947
|
19
19
|
plancraft/environment/recipes.py,sha256=0vwzOU86eZmGN2EpZVSIvzxpx0AOBWNPxTtAOFBN2A0,19570
|
20
20
|
plancraft/environment/sampler.py,sha256=79hLpTU0ajvMPoBsvSe8tE88x31c8Vlczb3tJZJcau0,7441
|
21
|
-
plancraft/environment/search.py,sha256=
|
21
|
+
plancraft/environment/search.py,sha256=kk6t-MkpFGTL7I38GQ6H21BjW9qJLSNGMbJqvZhr1LE,2035
|
22
22
|
plancraft/environment/assets/constants.json,sha256=kyOIOh82CTTMMGEIS60k5k6M-6fkEmYDoGAnvi3Zx5k,1379016
|
23
23
|
plancraft/environment/assets/minecraft_font.ttf,sha256=AzoK9cgggXwjFPHtIO7uz-YaDrminl3nvB-VsaTvTAk,60992
|
24
24
|
plancraft/environment/assets/table.png,sha256=IKIViZKAPyR4FWnS0JP9AZ19vIEO3qoS5-YRGAO1ow8,5430
|
@@ -1915,12 +1915,12 @@ plancraft/models/__init__.py,sha256=TBrarn93qt4IFJRNqtzOfaA8jGMPCgD7DFs-M84ipmk,
|
|
1915
1915
|
plancraft/models/act.py,sha256=6Xb8rylg3OngOraVFgduH_hQR62VcoyTeFntN4q3hsQ,2691
|
1916
1916
|
plancraft/models/base.py,sha256=uhG1tRmsBerJzW8qHoLyLEYpveDv0co7AAhi4mSfyO4,661
|
1917
1917
|
plancraft/models/bbox_model.py,sha256=3b1IEspoHiVUR6GOWjEbp4YoxRhGkzKt-eOiwaN8NXo,17091
|
1918
|
-
plancraft/models/dummy.py,sha256=
|
1918
|
+
plancraft/models/dummy.py,sha256=3Nsnw12s_n5mWMuxUTaPCuJIzPp0vLHWKE827iKY5o0,1391
|
1919
1919
|
plancraft/models/generators.py,sha256=F76_iPiqxUjDIrQwF58tzM0bLM91OkZJ0sBqBuki5wY,13939
|
1920
|
-
plancraft/models/oracle.py,sha256=
|
1920
|
+
plancraft/models/oracle.py,sha256=f-0KWlBuHy6wcxmDsxM3MQ_QwfBstzfbA26mlk1MgLA,1657
|
1921
1921
|
plancraft/models/utils.py,sha256=E-sZohvolWgGbpHQKgAgkgIfUJoVnT5pMt6JP8xLHKg,4034
|
1922
1922
|
plancraft/train/dataset.py,sha256=oFqEd4LG9oEQ-71teh0Wf7-jJbtybT2ZibfM2bBdBkM,5474
|
1923
|
-
plancraft-0.3.
|
1924
|
-
plancraft-0.3.
|
1925
|
-
plancraft-0.3.
|
1926
|
-
plancraft-0.3.
|
1923
|
+
plancraft-0.3.18.dist-info/METADATA,sha256=p_Ln_3jx77ygBZG6yjuLhVs883PysUXUCi1sK67QvJs,11148
|
1924
|
+
plancraft-0.3.18.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
1925
|
+
plancraft-0.3.18.dist-info/licenses/LICENSE,sha256=YGR8ehDB4t-T-lOQKMfKNR-2zsOU7E3E5NA8t25HKE0,1070
|
1926
|
+
plancraft-0.3.18.dist-info/RECORD,,
|
File without changes
|
File without changes
|