plancraft 0.3.17__py3-none-any.whl → 0.3.18__py3-none-any.whl

This diff shows the changes between two publicly available versions of this package that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
plancraft/evaluator.py CHANGED
@@ -176,12 +176,14 @@ class Evaluator:
176
176
  example: PlancraftExample,
177
177
  model: PlancraftBaseModel,
178
178
  ) -> dict:
179
- """Given the loaded model and an example from Plancraft
180
- run the episode until success or termination."""
179
+ """
180
+ Given the loaded model and an example from Plancraft
181
+ run the episode until success or termination.
182
+ """
181
183
 
182
184
  # start environment
183
185
  environment = PlancraftEnvironment(
184
- inventory=example.slotted_inventory,
186
+ inventory=deepcopy(example.slotted_inventory),
185
187
  resolution=self.resolution,
186
188
  )
187
189
 
@@ -252,6 +254,135 @@ class Evaluator:
252
254
  "images": history.images,
253
255
  }
254
256
 
257
+ def batch_eval_examples(
258
+ self,
259
+ examples: list[PlancraftExample],
260
+ model,
261
+ ) -> list:
262
+ # Initialize environments and histories
263
+ environments = [
264
+ PlancraftEnvironment(
265
+ inventory=deepcopy(examples[i].slotted_inventory),
266
+ resolution=self.resolution,
267
+ )
268
+ for i in range(len(examples))
269
+ ]
270
+
271
+ histories = [
272
+ History(
273
+ actions=self.actions,
274
+ use_multimodal_content_format=self.use_multimodal_content_format,
275
+ use_images=self.use_images,
276
+ use_text_inventory=self.use_text_inventory,
277
+ resolution=self.resolution,
278
+ few_shot=self.few_shot,
279
+ system_prompt=deepcopy(self.system_prompt),
280
+ prompt_examples=deepcopy(self.prompt_examples),
281
+ prompt_images=deepcopy(self.prompt_images),
282
+ )
283
+ for _ in range(len(examples))
284
+ ]
285
+
286
+ # Track which environments are still active
287
+ active_mask = [True for _ in range(len(examples))]
288
+ results = [None for _ in range(len(examples))]
289
+ steps_taken = [0 for _ in range(len(examples))]
290
+ actions = [None for _ in range(len(examples))]
291
+
292
+ while any(active_mask) and all(steps < self.max_steps for steps in steps_taken):
293
+ # Get observations for all active environments
294
+ observations = []
295
+ active_indices = []
296
+ for i, (env, action, active) in enumerate(
297
+ zip(environments, actions, active_mask)
298
+ ):
299
+ if not active:
300
+ continue
301
+
302
+ if isinstance(action, StopAction):
303
+ # Handle stop action
304
+ active_mask[i] = False
305
+ results[i] = {
306
+ "success": examples[i].impossible,
307
+ "recipe_type": examples[i].recipe_type,
308
+ "complexity": examples[i].complexity_split,
309
+ "number_of_steps": steps_taken[i],
310
+ "model_trace": histories[i].trace(),
311
+ "example_id": examples[i].id,
312
+ "images": histories[i].images,
313
+ }
314
+ logger.info("STOP")
315
+ continue
316
+
317
+ active_indices.append(i)
318
+ if isinstance(action, str):
319
+ # Handle message action
320
+ obs = env.step()
321
+ obs["target"] = examples[i].target
322
+ obs["message"] = action
323
+ else:
324
+ # Handle environment action
325
+ obs = env.step(action)
326
+ obs["target"] = examples[i].target
327
+ obs["message"] = self.convert_observation_to_message(
328
+ obs, model=model
329
+ )
330
+
331
+ # Check if done
332
+ if self.check_done(obs["inventory"], examples[i].target):
333
+ active_mask[i] = False
334
+ results[i] = {
335
+ "success": True,
336
+ "recipe_type": examples[i].recipe_type,
337
+ "complexity": examples[i].complexity_split,
338
+ "number_of_steps": steps_taken[i],
339
+ "model_trace": histories[i].trace(),
340
+ "example_id": examples[i].id,
341
+ "images": histories[i].images,
342
+ }
343
+ continue
344
+
345
+ observations.append(obs)
346
+ histories[i].add_observation_to_history(obs)
347
+ histories[i].add_message_to_history(content=obs["message"], role="user")
348
+ steps_taken[i] += 1
349
+
350
+ if not observations:
351
+ break
352
+
353
+ # Batch predict actions for active environments
354
+ active_histories = [histories[i] for i in active_indices]
355
+ raw_actions = model.batch_step(
356
+ observations, dialogue_histories=active_histories
357
+ )
358
+
359
+ # Process actions for each active environment
360
+ for idx, raw_action in zip(active_indices, raw_actions):
361
+ logger.info(f"{histories[idx].num_steps}, {raw_action}")
362
+ histories[idx].add_message_to_history(
363
+ content=raw_action, role="assistant"
364
+ )
365
+ actions[idx] = self.parse_raw_model_response(
366
+ raw_action,
367
+ observation=observations[active_indices.index(idx)],
368
+ history=histories[idx],
369
+ )
370
+
371
+ # Fill in results for environments that didn't finish
372
+ for i, result in enumerate(results):
373
+ if result is None:
374
+ results[i] = {
375
+ "success": False,
376
+ "recipe_type": examples[i].recipe_type,
377
+ "complexity": examples[i].complexity_split,
378
+ "number_of_steps": steps_taken[i],
379
+ "model_trace": histories[i].trace(),
380
+ "example_id": examples[i].id,
381
+ "images": histories[i].images,
382
+ }
383
+
384
+ return results
385
+
255
386
  def eval_all_examples(self, model, progress_bar=False) -> list:
256
387
  results = []
257
388
  pbar = tqdm(
plancraft/models/dummy.py CHANGED
@@ -40,3 +40,6 @@ class DummyModel(PlancraftBaseModel):
40
40
 
41
41
  def step(self, observation: dict, **kwargs) -> str:
42
42
  return str(self.random_select(observation))
43
+
44
+ def batch_step(self, observations: list[dict], **kwargs) -> list:
45
+ return [self.step(observation) for observation in observations]
@@ -38,3 +38,12 @@ class OracleModel(PlancraftBaseModel):
38
38
 
39
39
  action = self.subplans.pop(0)
40
40
  return action
41
+
42
+ def batch_step(self, observations: list[dict], **kwargs) -> list:
43
+ # Need to fully isolate state between examples
44
+ actions = []
45
+ for observation in observations:
46
+ self.reset()
47
+ action = self.step(observation)
48
+ actions.append(action)
49
+ return actions
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: plancraft
3
- Version: 0.3.17
3
+ Version: 0.3.18
4
4
  Summary: Plancraft: an evaluation dataset for planning with LLM agents
5
5
  License: MIT License
6
6
 
@@ -1,6 +1,6 @@
1
1
  plancraft/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  plancraft/config.py,sha256=ShsFRlJ7plsl3ToD9fiO_4LDQuXdbjNV6Xp6o3Yk2Yg,4315
3
- plancraft/evaluator.py,sha256=q7khX8FrMeb5QOgYZba-24jC7ZXp83VU7sa1H1kKS08,11061
3
+ plancraft/evaluator.py,sha256=v8itX8buduqTZdR39gtLwdhKGEnSX3rJv9Yd13EzNgQ,16395
4
4
  plancraft/generate_dataset.py,sha256=DlrU-PmvWqSNJD1g1-8Lpb8n3N-Ogw3rje1nrRzjGKs,2382
5
5
  plancraft/utils.py,sha256=67UUDMSv8TqX_I0fL5-yG_vkHvTZlnhSLkktWAg5p34,5712
6
6
  plancraft/data/test.json,sha256=4jWfYMAVuZCFmGB4iZJAjlh9_8jXECdaGp8xn7_tAM4,1317131
@@ -1915,12 +1915,12 @@ plancraft/models/__init__.py,sha256=TBrarn93qt4IFJRNqtzOfaA8jGMPCgD7DFs-M84ipmk,
1915
1915
  plancraft/models/act.py,sha256=6Xb8rylg3OngOraVFgduH_hQR62VcoyTeFntN4q3hsQ,2691
1916
1916
  plancraft/models/base.py,sha256=uhG1tRmsBerJzW8qHoLyLEYpveDv0co7AAhi4mSfyO4,661
1917
1917
  plancraft/models/bbox_model.py,sha256=3b1IEspoHiVUR6GOWjEbp4YoxRhGkzKt-eOiwaN8NXo,17091
1918
- plancraft/models/dummy.py,sha256=856oEX6NquXSIIfQLTEFFeB8ib7VUUs5cB0TVHAiFvI,1248
1918
+ plancraft/models/dummy.py,sha256=3Nsnw12s_n5mWMuxUTaPCuJIzPp0vLHWKE827iKY5o0,1391
1919
1919
  plancraft/models/generators.py,sha256=F76_iPiqxUjDIrQwF58tzM0bLM91OkZJ0sBqBuki5wY,13939
1920
- plancraft/models/oracle.py,sha256=tMp9mTwD70T3qohj-LZhJFjHYWyiVHDh8gu27asVimI,1342
1920
+ plancraft/models/oracle.py,sha256=f-0KWlBuHy6wcxmDsxM3MQ_QwfBstzfbA26mlk1MgLA,1657
1921
1921
  plancraft/models/utils.py,sha256=E-sZohvolWgGbpHQKgAgkgIfUJoVnT5pMt6JP8xLHKg,4034
1922
1922
  plancraft/train/dataset.py,sha256=oFqEd4LG9oEQ-71teh0Wf7-jJbtybT2ZibfM2bBdBkM,5474
1923
- plancraft-0.3.17.dist-info/METADATA,sha256=DhUcHfnj_fMJTnHQVQVl50RpM3mwPCdFHOFNQtCo39c,11148
1924
- plancraft-0.3.17.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
1925
- plancraft-0.3.17.dist-info/licenses/LICENSE,sha256=YGR8ehDB4t-T-lOQKMfKNR-2zsOU7E3E5NA8t25HKE0,1070
1926
- plancraft-0.3.17.dist-info/RECORD,,
1923
+ plancraft-0.3.18.dist-info/METADATA,sha256=p_Ln_3jx77ygBZG6yjuLhVs883PysUXUCi1sK67QvJs,11148
1924
+ plancraft-0.3.18.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
1925
+ plancraft-0.3.18.dist-info/licenses/LICENSE,sha256=YGR8ehDB4t-T-lOQKMfKNR-2zsOU7E3E5NA8t25HKE0,1070
1926
+ plancraft-0.3.18.dist-info/RECORD,,