plancraft 0.3.17__py3-none-any.whl → 0.3.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plancraft/evaluator.py +134 -3
- plancraft/models/dummy.py +3 -0
- plancraft/models/oracle.py +9 -0
- {plancraft-0.3.17.dist-info → plancraft-0.3.18.dist-info}/METADATA +1 -1
- {plancraft-0.3.17.dist-info → plancraft-0.3.18.dist-info}/RECORD +7 -7
- {plancraft-0.3.17.dist-info → plancraft-0.3.18.dist-info}/WHEEL +0 -0
- {plancraft-0.3.17.dist-info → plancraft-0.3.18.dist-info}/licenses/LICENSE +0 -0
plancraft/evaluator.py
CHANGED
@@ -176,12 +176,14 @@ class Evaluator:
|
|
176
176
|
example: PlancraftExample,
|
177
177
|
model: PlancraftBaseModel,
|
178
178
|
) -> dict:
|
179
|
-
"""
|
180
|
-
|
179
|
+
"""
|
180
|
+
Given the loaded model and an example from Plancraft
|
181
|
+
run the episode until success or termination.
|
182
|
+
"""
|
181
183
|
|
182
184
|
# start environment
|
183
185
|
environment = PlancraftEnvironment(
|
184
|
-
inventory=example.slotted_inventory,
|
186
|
+
inventory=deepcopy(example.slotted_inventory),
|
185
187
|
resolution=self.resolution,
|
186
188
|
)
|
187
189
|
|
@@ -252,6 +254,135 @@ class Evaluator:
|
|
252
254
|
"images": history.images,
|
253
255
|
}
|
254
256
|
|
257
|
+
def batch_eval_examples(
    self,
    examples: list[PlancraftExample],
    model,
) -> list:
    """Run a batch of Plancraft episodes in lockstep with batched model calls.

    Every example gets its own ``PlancraftEnvironment`` (built from a deep
    copy of its slotted inventory) and its own ``History``. On each
    iteration the still-active environments apply their pending action,
    finished episodes are recorded, and the observations of the remaining
    episodes are sent to ``model.batch_step`` in a single call.

    Args:
        examples: Examples to evaluate.
        model: Model exposing ``batch_step(observations, dialogue_histories=...)``
            that returns one raw response per observation, in order.

    Returns:
        One result dict per example, in the same order as ``examples``, with
        keys ``success``, ``recipe_type``, ``complexity``, ``number_of_steps``,
        ``model_trace``, ``example_id`` and ``images``.

    Raises:
        ValueError: If ``model.batch_step`` returns a different number of
            responses than the observations it was given.
    """
    num_examples = len(examples)

    environments = [
        PlancraftEnvironment(
            inventory=deepcopy(example.slotted_inventory),
            resolution=self.resolution,
        )
        for example in examples
    ]
    # Deep-copy the prompt material so no two histories share mutable state.
    histories = [
        History(
            actions=self.actions,
            use_multimodal_content_format=self.use_multimodal_content_format,
            use_images=self.use_images,
            use_text_inventory=self.use_text_inventory,
            resolution=self.resolution,
            few_shot=self.few_shot,
            system_prompt=deepcopy(self.system_prompt),
            prompt_examples=deepcopy(self.prompt_examples),
            prompt_images=deepcopy(self.prompt_images),
        )
        for _ in range(num_examples)
    ]

    active_mask = [True] * num_examples
    results = [None] * num_examples
    steps_taken = [0] * num_examples
    # Pending action per environment; None before the first model response.
    actions = [None] * num_examples

    def _result(i: int, success: bool) -> dict:
        """Build the result record for example ``i`` from its current state."""
        return {
            "success": success,
            "recipe_type": examples[i].recipe_type,
            "complexity": examples[i].complexity_split,
            "number_of_steps": steps_taken[i],
            "model_trace": histories[i].trace(),
            "example_id": examples[i].id,
            "images": histories[i].images,
        }

    # Active environments always advance together, so they share the same
    # step count; the batch stops once that count reaches max_steps.
    while any(active_mask) and all(s < self.max_steps for s in steps_taken):
        observations = []
        active_indices = []
        for i, env in enumerate(environments):
            if not active_mask[i]:
                continue

            action = actions[i]
            if isinstance(action, StopAction):
                # Stopping counts as success exactly when the example is impossible.
                active_mask[i] = False
                results[i] = _result(i, success=examples[i].impossible)
                logger.info("STOP")
                continue

            if isinstance(action, str):
                # Message action: step without an action and echo the model's message.
                obs = env.step()
                obs["target"] = examples[i].target
                obs["message"] = action
            else:
                # Environment action (None on the first iteration).
                obs = env.step(action)
                obs["target"] = examples[i].target
                obs["message"] = self.convert_observation_to_message(
                    obs, model=model
                )

            if self.check_done(obs["inventory"], examples[i].target):
                active_mask[i] = False
                results[i] = _result(i, success=True)
                continue

            # Record the index only once this environment actually yields an
            # observation, so active_indices stays aligned with observations.
            active_indices.append(i)
            observations.append(obs)
            histories[i].add_observation_to_history(obs)
            histories[i].add_message_to_history(content=obs["message"], role="user")
            steps_taken[i] += 1

        if not observations:
            break

        raw_actions = list(
            model.batch_step(
                observations,
                dialogue_histories=[histories[i] for i in active_indices],
            )
        )
        # zip() would silently drop extras and leave stale actions behind.
        if len(raw_actions) != len(observations):
            raise ValueError(
                f"model.batch_step returned {len(raw_actions)} responses "
                f"for {len(observations)} observations"
            )

        for idx, obs, raw_action in zip(active_indices, observations, raw_actions):
            logger.info(f"{histories[idx].num_steps}, {raw_action}")
            histories[idx].add_message_to_history(content=raw_action, role="assistant")
            actions[idx] = self.parse_raw_model_response(
                raw_action,
                observation=obs,
                history=histories[idx],
            )

    # Episodes that neither stopped nor completed are recorded as failures.
    for i, result in enumerate(results):
        if result is None:
            results[i] = _result(i, success=False)

    return results
|
385
|
+
|
255
386
|
def eval_all_examples(self, model, progress_bar=False) -> list:
|
256
387
|
results = []
|
257
388
|
pbar = tqdm(
|
plancraft/models/dummy.py
CHANGED
@@ -40,3 +40,6 @@ class DummyModel(PlancraftBaseModel):
|
|
40
40
|
|
41
41
|
def step(self, observation: dict, **kwargs) -> str:
    """Return the action chosen by ``self.random_select`` for ``observation``, as a string.

    Extra keyword arguments are accepted for interface compatibility and ignored.
    """
    selected = self.random_select(observation)
    return str(selected)
|
43
|
+
|
44
|
+
def batch_step(self, observations: list[dict], **kwargs) -> list:
    """Apply ``self.step`` to each observation independently, preserving order.

    Extra keyword arguments (for example ``dialogue_histories``) are accepted
    for interface compatibility and ignored.
    """
    return list(map(self.step, observations))
|
plancraft/models/oracle.py
CHANGED
@@ -38,3 +38,12 @@ class OracleModel(PlancraftBaseModel):
|
|
38
38
|
|
39
39
|
action = self.subplans.pop(0)
|
40
40
|
return action
|
41
|
+
|
42
|
+
def batch_step(self, observations: list[dict], **kwargs) -> list:
    """Return one oracle action per observation, in order.

    ``self.reset()`` is called before every ``self.step`` so that state from the
    previous observation is not carried over to the next example in the batch
    (presumably this clears the pending ``subplans``; confirm against ``reset``).
    Extra keyword arguments are accepted for interface compatibility and ignored.
    """
    results = []
    for obs in observations:
        # Fully isolate oracle state between examples.
        self.reset()
        results.append(self.step(obs))
    return results
|
@@ -1,6 +1,6 @@
|
|
1
1
|
plancraft/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
plancraft/config.py,sha256=ShsFRlJ7plsl3ToD9fiO_4LDQuXdbjNV6Xp6o3Yk2Yg,4315
|
3
|
-
plancraft/evaluator.py,sha256=
|
3
|
+
plancraft/evaluator.py,sha256=v8itX8buduqTZdR39gtLwdhKGEnSX3rJv9Yd13EzNgQ,16395
|
4
4
|
plancraft/generate_dataset.py,sha256=DlrU-PmvWqSNJD1g1-8Lpb8n3N-Ogw3rje1nrRzjGKs,2382
|
5
5
|
plancraft/utils.py,sha256=67UUDMSv8TqX_I0fL5-yG_vkHvTZlnhSLkktWAg5p34,5712
|
6
6
|
plancraft/data/test.json,sha256=4jWfYMAVuZCFmGB4iZJAjlh9_8jXECdaGp8xn7_tAM4,1317131
|
@@ -1915,12 +1915,12 @@ plancraft/models/__init__.py,sha256=TBrarn93qt4IFJRNqtzOfaA8jGMPCgD7DFs-M84ipmk,
|
|
1915
1915
|
plancraft/models/act.py,sha256=6Xb8rylg3OngOraVFgduH_hQR62VcoyTeFntN4q3hsQ,2691
|
1916
1916
|
plancraft/models/base.py,sha256=uhG1tRmsBerJzW8qHoLyLEYpveDv0co7AAhi4mSfyO4,661
|
1917
1917
|
plancraft/models/bbox_model.py,sha256=3b1IEspoHiVUR6GOWjEbp4YoxRhGkzKt-eOiwaN8NXo,17091
|
1918
|
-
plancraft/models/dummy.py,sha256=
|
1918
|
+
plancraft/models/dummy.py,sha256=3Nsnw12s_n5mWMuxUTaPCuJIzPp0vLHWKE827iKY5o0,1391
|
1919
1919
|
plancraft/models/generators.py,sha256=F76_iPiqxUjDIrQwF58tzM0bLM91OkZJ0sBqBuki5wY,13939
|
1920
|
-
plancraft/models/oracle.py,sha256=
|
1920
|
+
plancraft/models/oracle.py,sha256=f-0KWlBuHy6wcxmDsxM3MQ_QwfBstzfbA26mlk1MgLA,1657
|
1921
1921
|
plancraft/models/utils.py,sha256=E-sZohvolWgGbpHQKgAgkgIfUJoVnT5pMt6JP8xLHKg,4034
|
1922
1922
|
plancraft/train/dataset.py,sha256=oFqEd4LG9oEQ-71teh0Wf7-jJbtybT2ZibfM2bBdBkM,5474
|
1923
|
-
plancraft-0.3.
|
1924
|
-
plancraft-0.3.
|
1925
|
-
plancraft-0.3.
|
1926
|
-
plancraft-0.3.
|
1923
|
+
plancraft-0.3.18.dist-info/METADATA,sha256=p_Ln_3jx77ygBZG6yjuLhVs883PysUXUCi1sK67QvJs,11148
|
1924
|
+
plancraft-0.3.18.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
1925
|
+
plancraft-0.3.18.dist-info/licenses/LICENSE,sha256=YGR8ehDB4t-T-lOQKMfKNR-2zsOU7E3E5NA8t25HKE0,1070
|
1926
|
+
plancraft-0.3.18.dist-info/RECORD,,
|
File without changes
|
File without changes
|