plancraft 0.3.17__py3-none-any.whl → 0.3.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
plancraft/evaluator.py CHANGED
@@ -4,7 +4,6 @@ from typing import Optional
4
4
  from copy import deepcopy
5
5
 
6
6
  import imageio
7
- from loguru import logger
8
7
  from tqdm import tqdm
9
8
 
10
9
  import wandb
@@ -176,12 +175,14 @@ class Evaluator:
176
175
  example: PlancraftExample,
177
176
  model: PlancraftBaseModel,
178
177
  ) -> dict:
179
- """Given the loaded model and an example from Plancraft
180
- run the episode until success or termination."""
178
+ """
179
+ Given the loaded model and an example from Plancraft
180
+ run the episode until success or termination.
181
+ """
181
182
 
182
183
  # start environment
183
184
  environment = PlancraftEnvironment(
184
- inventory=example.slotted_inventory,
185
+ inventory=deepcopy(example.slotted_inventory),
185
186
  resolution=self.resolution,
186
187
  )
187
188
 
@@ -252,6 +253,140 @@ class Evaluator:
252
253
  "images": history.images,
253
254
  }
254
255
 
256
def batch_eval_examples(
    self,
    examples: list[PlancraftExample],
    model,
) -> list:
    """
    Run several Plancraft examples in lockstep, batching the model calls.

    Each example gets its own environment and dialogue history. On every
    iteration, each still-active environment executes its pending action,
    the resulting observations are collected, and ``model.batch_step`` is
    called once for all of them.

    Args:
        examples: examples to evaluate; one result dict is produced per example.
        model: object exposing ``batch_step(observations, dialogue_histories=...)``
            that returns one raw action per observation, in the same order.

    Returns:
        A list of result dicts aligned index-for-index with ``examples``.

    Raises:
        ValueError: if ``model.batch_step`` returns a different number of
            actions than the number of observations it was given.
    """
    num_examples = len(examples)

    environments = [
        PlancraftEnvironment(
            # deepcopy so stepping the environment never mutates the example
            inventory=deepcopy(example.slotted_inventory),
            resolution=self.resolution,
        )
        for example in examples
    ]

    histories = [
        History(
            actions=self.actions,
            use_multimodal_content_format=self.use_multimodal_content_format,
            use_images=self.use_images,
            use_text_inventory=self.use_text_inventory,
            resolution=self.resolution,
            few_shot=self.few_shot,
            system_prompt=deepcopy(self.system_prompt),
            prompt_examples=deepcopy(self.prompt_examples),
            prompt_images=deepcopy(self.prompt_images),
        )
        for _ in examples
    ]

    active_mask = [True] * num_examples
    results = [None] * num_examples
    steps_taken = [0] * num_examples
    # Pending action per environment; None means "no action yet" (first step).
    actions = [None] * num_examples

    def build_result(i, success):
        # Snapshot the episode outcome for example ``i`` at the current step.
        return {
            "success": success,
            "recipe_type": examples[i].recipe_type,
            "complexity": examples[i].complexity_split,
            "number_of_steps": steps_taken[i],
            "model_trace": histories[i].trace(),
            "example_id": examples[i].id,
            "images": histories[i].images,
        }

    while any(active_mask) and all(steps < self.max_steps for steps in steps_taken):
        observations = []
        # Indices of environments that produced an observation this round,
        # kept in exactly the same order as ``observations``.
        stepped_indices = []

        for i, (env, action, active) in enumerate(
            zip(environments, actions, active_mask)
        ):
            if not active:
                continue

            if isinstance(action, StopAction):
                # Stopping counts as success only for impossible examples.
                active_mask[i] = False
                results[i] = build_result(i, success=examples[i].impossible)
                continue

            if isinstance(action, str):
                # Message action: advance without an environment action and
                # feed the model's message back as the observation text.
                obs = env.step()
                obs["target"] = examples[i].target
                obs["message"] = action
            else:
                obs = env.step(action)
                obs["target"] = examples[i].target
                obs["message"] = self.convert_observation_to_message(
                    obs, model=model
                )

            if self.check_done(obs["inventory"], examples[i].target):
                active_mask[i] = False
                results[i] = build_result(i, success=True)
                # Finished: must NOT be part of this round's model batch.
                continue

            stepped_indices.append(i)
            observations.append(obs)
            histories[i].add_observation_to_history(obs)
            histories[i].add_message_to_history(content=obs["message"], role="user")
            steps_taken[i] += 1

        if not observations:
            break

        raw_actions = model.batch_step(
            observations,
            dialogue_histories=[histories[i] for i in stepped_indices],
        )
        if len(raw_actions) != len(observations):
            raise ValueError(
                f"model.batch_step returned {len(raw_actions)} actions "
                f"for {len(observations)} observations"
            )

        for i, obs, raw_action in zip(stepped_indices, observations, raw_actions):
            histories[i].add_message_to_history(content=raw_action, role="assistant")
            actions[i] = self.parse_raw_model_response(
                raw_action,
                observation=obs,
                history=histories[i],
            )

    # Any example still without a result ran out of steps (or the loop ended).
    for i, result in enumerate(results):
        if result is None:
            results[i] = build_result(i, success=False)

    return results
+
255
390
  def eval_all_examples(self, model, progress_bar=False) -> list:
256
391
  results = []
257
392
  pbar = tqdm(
plancraft/models/dummy.py CHANGED
@@ -40,3 +40,6 @@ class DummyModel(PlancraftBaseModel):
40
40
 
41
41
def step(self, observation: dict, **kwargs) -> str:
    """Pick an action via ``self.random_select`` and return it as a string.

    Extra keyword arguments are accepted for interface compatibility and ignored.
    """
    selection = self.random_select(observation)
    return str(selection)
43
+
44
def batch_step(self, observations: list[dict], **kwargs) -> list:
    """Run ``self.step`` on each observation in order and collect the outputs.

    Keyword arguments (e.g. ``dialogue_histories``) are accepted but not used.
    """
    return list(map(self.step, observations))
@@ -38,3 +38,12 @@ class OracleModel(PlancraftBaseModel):
38
38
 
39
39
  action = self.subplans.pop(0)
40
40
  return action
41
+
42
def batch_step(self, observations: list[dict], **kwargs) -> list:
    """
    Produce one oracle action per observation.

    ``self.reset()`` is invoked before every ``self.step`` call so that state
    kept on the instance does not leak from one example to the next
    (presumably the queued subplans — confirm against ``reset``).
    Keyword arguments are accepted but not used.
    """

    def fresh_step(obs):
        # Isolate each example: start from a clean oracle state.
        self.reset()
        return self.step(obs)

    return [fresh_step(obs) for obs in observations]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: plancraft
3
- Version: 0.3.17
3
+ Version: 0.3.19
4
4
  Summary: Plancraft: an evaluation dataset for planning with LLM agents
5
5
  License: MIT License
6
6
 
@@ -1,6 +1,6 @@
1
1
  plancraft/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  plancraft/config.py,sha256=ShsFRlJ7plsl3ToD9fiO_4LDQuXdbjNV6Xp6o3Yk2Yg,4315
3
- plancraft/evaluator.py,sha256=q7khX8FrMeb5QOgYZba-24jC7ZXp83VU7sa1H1kKS08,11061
3
+ plancraft/evaluator.py,sha256=SXg7UoZ9KqjxgLkBo53BMw9XdR2nwMx7mDEXVwVhBiM,16544
4
4
  plancraft/generate_dataset.py,sha256=DlrU-PmvWqSNJD1g1-8Lpb8n3N-Ogw3rje1nrRzjGKs,2382
5
5
  plancraft/utils.py,sha256=67UUDMSv8TqX_I0fL5-yG_vkHvTZlnhSLkktWAg5p34,5712
6
6
  plancraft/data/test.json,sha256=4jWfYMAVuZCFmGB4iZJAjlh9_8jXECdaGp8xn7_tAM4,1317131
@@ -1915,12 +1915,12 @@ plancraft/models/__init__.py,sha256=TBrarn93qt4IFJRNqtzOfaA8jGMPCgD7DFs-M84ipmk,
1915
1915
  plancraft/models/act.py,sha256=6Xb8rylg3OngOraVFgduH_hQR62VcoyTeFntN4q3hsQ,2691
1916
1916
  plancraft/models/base.py,sha256=uhG1tRmsBerJzW8qHoLyLEYpveDv0co7AAhi4mSfyO4,661
1917
1917
  plancraft/models/bbox_model.py,sha256=3b1IEspoHiVUR6GOWjEbp4YoxRhGkzKt-eOiwaN8NXo,17091
1918
- plancraft/models/dummy.py,sha256=856oEX6NquXSIIfQLTEFFeB8ib7VUUs5cB0TVHAiFvI,1248
1918
+ plancraft/models/dummy.py,sha256=3Nsnw12s_n5mWMuxUTaPCuJIzPp0vLHWKE827iKY5o0,1391
1919
1919
  plancraft/models/generators.py,sha256=F76_iPiqxUjDIrQwF58tzM0bLM91OkZJ0sBqBuki5wY,13939
1920
- plancraft/models/oracle.py,sha256=tMp9mTwD70T3qohj-LZhJFjHYWyiVHDh8gu27asVimI,1342
1920
+ plancraft/models/oracle.py,sha256=f-0KWlBuHy6wcxmDsxM3MQ_QwfBstzfbA26mlk1MgLA,1657
1921
1921
  plancraft/models/utils.py,sha256=E-sZohvolWgGbpHQKgAgkgIfUJoVnT5pMt6JP8xLHKg,4034
1922
1922
  plancraft/train/dataset.py,sha256=oFqEd4LG9oEQ-71teh0Wf7-jJbtybT2ZibfM2bBdBkM,5474
1923
- plancraft-0.3.17.dist-info/METADATA,sha256=DhUcHfnj_fMJTnHQVQVl50RpM3mwPCdFHOFNQtCo39c,11148
1924
- plancraft-0.3.17.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
1925
- plancraft-0.3.17.dist-info/licenses/LICENSE,sha256=YGR8ehDB4t-T-lOQKMfKNR-2zsOU7E3E5NA8t25HKE0,1070
1926
- plancraft-0.3.17.dist-info/RECORD,,
1923
+ plancraft-0.3.19.dist-info/METADATA,sha256=mA9CYITxxP37FICRAeVEq9j1WyBEeaWnVAPJsQnA1Ho,11148
1924
+ plancraft-0.3.19.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
1925
+ plancraft-0.3.19.dist-info/licenses/LICENSE,sha256=YGR8ehDB4t-T-lOQKMfKNR-2zsOU7E3E5NA8t25HKE0,1070
1926
+ plancraft-0.3.19.dist-info/RECORD,,