plancraft 0.3.17__py3-none-any.whl → 0.3.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plancraft/evaluator.py +139 -4
- plancraft/models/dummy.py +3 -0
- plancraft/models/oracle.py +9 -0
- {plancraft-0.3.17.dist-info → plancraft-0.3.19.dist-info}/METADATA +1 -1
- {plancraft-0.3.17.dist-info → plancraft-0.3.19.dist-info}/RECORD +7 -7
- {plancraft-0.3.17.dist-info → plancraft-0.3.19.dist-info}/WHEEL +0 -0
- {plancraft-0.3.17.dist-info → plancraft-0.3.19.dist-info}/licenses/LICENSE +0 -0
plancraft/evaluator.py
CHANGED
@@ -4,7 +4,6 @@ from typing import Optional
|
|
4
4
|
from copy import deepcopy
|
5
5
|
|
6
6
|
import imageio
|
7
|
-
from loguru import logger
|
8
7
|
from tqdm import tqdm
|
9
8
|
|
10
9
|
import wandb
|
@@ -176,12 +175,14 @@ class Evaluator:
|
|
176
175
|
example: PlancraftExample,
|
177
176
|
model: PlancraftBaseModel,
|
178
177
|
) -> dict:
|
179
|
-
"""
|
180
|
-
|
178
|
+
"""
|
179
|
+
Given the loaded model and an example from Plancraft
|
180
|
+
run the episode until success or termination.
|
181
|
+
"""
|
181
182
|
|
182
183
|
# start environment
|
183
184
|
environment = PlancraftEnvironment(
|
184
|
-
inventory=example.slotted_inventory,
|
185
|
+
inventory=deepcopy(example.slotted_inventory),
|
185
186
|
resolution=self.resolution,
|
186
187
|
)
|
187
188
|
|
@@ -252,6 +253,140 @@ class Evaluator:
|
|
252
253
|
"images": history.images,
|
253
254
|
}
|
254
255
|
|
256
|
+
def batch_eval_examples(
    self,
    examples: list[PlancraftExample],
    model,
) -> list:
    """
    Run several Plancraft examples in lockstep, querying the model once per
    round through ``model.batch_step``.

    Every example gets its own environment (built from a deep copy of its
    slotted inventory) and its own dialogue history. An episode ends when
    its parsed action is a ``StopAction``, when ``self.check_done`` accepts
    the inventory after an environment action, or when the loop stops
    because the step budget ``self.max_steps`` has been reached.

    Args:
        examples: examples to evaluate; results keep this order.
        model: object exposing
            ``batch_step(observations, dialogue_histories=...)`` that
            returns one raw action per observation.

    Returns:
        One dict per example with the keys ``success``, ``recipe_type``,
        ``complexity``, ``number_of_steps``, ``model_trace``,
        ``example_id`` and ``images``.

    Raises:
        ValueError: if ``model.batch_step`` does not return exactly one
            action per observation it was given.
    """
    environments = [
        PlancraftEnvironment(
            inventory=deepcopy(example.slotted_inventory),
            resolution=self.resolution,
        )
        for example in examples
    ]

    histories = [
        History(
            actions=self.actions,
            use_multimodal_content_format=self.use_multimodal_content_format,
            use_images=self.use_images,
            use_text_inventory=self.use_text_inventory,
            resolution=self.resolution,
            few_shot=self.few_shot,
            system_prompt=deepcopy(self.system_prompt),
            prompt_examples=deepcopy(self.prompt_examples),
            prompt_images=deepcopy(self.prompt_images),
        )
        for _ in examples
    ]

    # Per-example bookkeeping, all indexed like ``examples``.
    active_mask = [True] * len(examples)
    results = [None] * len(examples)
    steps_taken = [0] * len(examples)
    actions = [None] * len(examples)

    def build_result(i: int, success) -> dict:
        # Snapshot of example ``i`` at the moment its episode is closed.
        return {
            "success": success,
            "recipe_type": examples[i].recipe_type,
            "complexity": examples[i].complexity_split,
            "number_of_steps": steps_taken[i],
            "model_trace": histories[i].trace(),
            "example_id": examples[i].id,
            "images": histories[i].images,
        }

    while any(active_mask) and all(
        steps < self.max_steps for steps in steps_taken
    ):
        # ``observations[k]`` always belongs to example ``active_indices[k]``.
        observations = []
        active_indices = []

        for i, (env, action, active) in enumerate(
            zip(environments, actions, active_mask)
        ):
            if not active:
                continue

            if isinstance(action, StopAction):
                # Stopping counts as success only for impossible examples.
                active_mask[i] = False
                results[i] = build_result(i, examples[i].impossible)
                continue

            if isinstance(action, str):
                # Message action: step without an action and echo the text.
                obs = env.step()
                obs["target"] = examples[i].target
                obs["message"] = action
            else:
                # Environment action (``None`` on the very first round).
                obs = env.step(action)
                obs["target"] = examples[i].target
                obs["message"] = self.convert_observation_to_message(
                    obs, model=model
                )

                if self.check_done(obs["inventory"], examples[i].target):
                    active_mask[i] = False
                    results[i] = build_result(i, True)
                    continue

            # Register the example only now that it really contributes an
            # observation. Registering it earlier left finished examples in
            # ``active_indices`` and shifted every later action onto the
            # wrong environment.
            active_indices.append(i)
            observations.append(obs)
            histories[i].add_observation_to_history(obs)
            histories[i].add_message_to_history(
                content=obs["message"], role="user"
            )
            steps_taken[i] += 1

        if not observations:
            break

        # Batch predict actions for the examples that are still running.
        active_histories = [histories[i] for i in active_indices]
        raw_actions = list(
            model.batch_step(observations, dialogue_histories=active_histories)
        )
        if len(raw_actions) != len(observations):
            # zip() would silently truncate and leave stale actions behind.
            raise ValueError(
                f"batch_step returned {len(raw_actions)} actions "
                f"for {len(observations)} observations"
            )

        for idx, obs, raw_action in zip(
            active_indices, observations, raw_actions
        ):
            histories[idx].add_message_to_history(
                content=raw_action, role="assistant"
            )
            actions[idx] = self.parse_raw_model_response(
                raw_action,
                observation=obs,
                history=histories[idx],
            )

    # Examples that never finished are reported as failures.
    for i, result in enumerate(results):
        if result is None:
            results[i] = build_result(i, False)

    return results
|
389
|
+
|
255
390
|
def eval_all_examples(self, model, progress_bar=False) -> list:
|
256
391
|
results = []
|
257
392
|
pbar = tqdm(
|
plancraft/models/dummy.py
CHANGED
@@ -40,3 +40,6 @@ class DummyModel(PlancraftBaseModel):
|
|
40
40
|
|
41
41
|
def step(self, observation: dict, **kwargs) -> str:
    """Return the string form of whatever ``random_select`` picks for the observation."""
    selection = self.random_select(observation)
    return str(selection)
|
43
|
+
|
44
|
+
def batch_step(self, observations: list[dict], **kwargs) -> list:
    """
    Call ``step`` once per observation and return the replies in input order.

    Extra keyword arguments are accepted but not passed on to ``step``.
    """
    return list(map(self.step, observations))
|
plancraft/models/oracle.py
CHANGED
@@ -38,3 +38,12 @@ class OracleModel(PlancraftBaseModel):
|
|
38
38
|
|
39
39
|
action = self.subplans.pop(0)
|
40
40
|
return action
|
41
|
+
|
42
|
+
def batch_step(self, observations: list[dict], **kwargs) -> list:
    """
    Produce one action per observation, in input order.

    ``reset`` is called before every ``step`` so that nothing carried on the
    model from one example influences the next.
    """

    def isolated_step(single_observation):
        # Reset first so each step runs from a freshly reset model.
        self.reset()
        return self.step(single_observation)

    return [isolated_step(single_observation) for single_observation in observations]
|
@@ -1,6 +1,6 @@
|
|
1
1
|
plancraft/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
plancraft/config.py,sha256=ShsFRlJ7plsl3ToD9fiO_4LDQuXdbjNV6Xp6o3Yk2Yg,4315
|
3
|
-
plancraft/evaluator.py,sha256=
|
3
|
+
plancraft/evaluator.py,sha256=SXg7UoZ9KqjxgLkBo53BMw9XdR2nwMx7mDEXVwVhBiM,16544
|
4
4
|
plancraft/generate_dataset.py,sha256=DlrU-PmvWqSNJD1g1-8Lpb8n3N-Ogw3rje1nrRzjGKs,2382
|
5
5
|
plancraft/utils.py,sha256=67UUDMSv8TqX_I0fL5-yG_vkHvTZlnhSLkktWAg5p34,5712
|
6
6
|
plancraft/data/test.json,sha256=4jWfYMAVuZCFmGB4iZJAjlh9_8jXECdaGp8xn7_tAM4,1317131
|
@@ -1915,12 +1915,12 @@ plancraft/models/__init__.py,sha256=TBrarn93qt4IFJRNqtzOfaA8jGMPCgD7DFs-M84ipmk,
|
|
1915
1915
|
plancraft/models/act.py,sha256=6Xb8rylg3OngOraVFgduH_hQR62VcoyTeFntN4q3hsQ,2691
|
1916
1916
|
plancraft/models/base.py,sha256=uhG1tRmsBerJzW8qHoLyLEYpveDv0co7AAhi4mSfyO4,661
|
1917
1917
|
plancraft/models/bbox_model.py,sha256=3b1IEspoHiVUR6GOWjEbp4YoxRhGkzKt-eOiwaN8NXo,17091
|
1918
|
-
plancraft/models/dummy.py,sha256=
|
1918
|
+
plancraft/models/dummy.py,sha256=3Nsnw12s_n5mWMuxUTaPCuJIzPp0vLHWKE827iKY5o0,1391
|
1919
1919
|
plancraft/models/generators.py,sha256=F76_iPiqxUjDIrQwF58tzM0bLM91OkZJ0sBqBuki5wY,13939
|
1920
|
-
plancraft/models/oracle.py,sha256=
|
1920
|
+
plancraft/models/oracle.py,sha256=f-0KWlBuHy6wcxmDsxM3MQ_QwfBstzfbA26mlk1MgLA,1657
|
1921
1921
|
plancraft/models/utils.py,sha256=E-sZohvolWgGbpHQKgAgkgIfUJoVnT5pMt6JP8xLHKg,4034
|
1922
1922
|
plancraft/train/dataset.py,sha256=oFqEd4LG9oEQ-71teh0Wf7-jJbtybT2ZibfM2bBdBkM,5474
|
1923
|
-
plancraft-0.3.
|
1924
|
-
plancraft-0.3.
|
1925
|
-
plancraft-0.3.
|
1926
|
-
plancraft-0.3.
|
1923
|
+
plancraft-0.3.19.dist-info/METADATA,sha256=mA9CYITxxP37FICRAeVEq9j1WyBEeaWnVAPJsQnA1Ho,11148
|
1924
|
+
plancraft-0.3.19.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
1925
|
+
plancraft-0.3.19.dist-info/licenses/LICENSE,sha256=YGR8ehDB4t-T-lOQKMfKNR-2zsOU7E3E5NA8t25HKE0,1070
|
1926
|
+
plancraft-0.3.19.dist-info/RECORD,,
|
File without changes
|
File without changes
|