plancraft 0.3.27__py3-none-any.whl → 0.3.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
plancraft/evaluator.py CHANGED
@@ -187,47 +187,24 @@ class Evaluator:
187
187
 
188
188
  # initialise history/dialogue tracking
189
189
  history = self.create_history()
190
+ observation = environment.step()
191
+ # add target and first message to history
192
+ observation["target"] = example.target
193
+ observation["message"] = self.convert_observation_to_message(
194
+ observation, model=model
195
+ )
190
196
 
191
197
  success = False
192
- action = None
193
-
194
198
  # run episode until stuck or until max steps is reached
195
199
  while history.num_steps < self.max_steps:
196
- # if the action is stop then we end the episode
197
- if isinstance(action, StopAction):
198
- # if the action is stop and task is impossible then success
199
- # otherwise we should not have stopped
200
- success = example.impossible
201
- break
202
- # action is external tool then it is str
203
- if isinstance(action, str):
204
- observation = environment.step()
205
- observation["target"] = example.target
206
- observation["message"] = action
207
- # action is environment action
208
- else:
209
- observation = environment.step(action)
210
- # convert inventory observation to text message
211
- observation["target"] = example.target
212
- observation["message"] = self.convert_observation_to_message(
213
- observation, model=model
214
- )
215
- # check if the episode is done
216
- success = self.check_done(observation["inventory"], example.target)
217
- # exit if success
218
- if success:
219
- break
220
-
221
200
  # add observation to history
222
201
  history.add_observation_to_history(observation)
223
- # add observation message to history
224
202
  history.add_message_to_history(content=observation["message"], role="user")
225
203
  # predict next action
226
204
  raw_action = model.step(observation, dialogue_history=history)
227
205
 
228
206
  # if the model returns a PlancraftModelOutput, extract the action
229
207
  if isinstance(raw_action, PlancraftModelOutput):
230
- # add message to history
231
208
  history.add_message_to_history(
232
209
  content=raw_action.action,
233
210
  role="assistant",
@@ -235,7 +212,6 @@ class Evaluator:
235
212
  )
236
213
  raw_action = raw_action.action
237
214
  elif isinstance(raw_action, str):
238
- # add message to history
239
215
  history.add_message_to_history(content=raw_action, role="assistant")
240
216
  else:
241
217
  raise ValueError(
@@ -247,6 +223,41 @@ class Evaluator:
247
223
  raw_action, observation=observation, history=history
248
224
  )
249
225
 
226
+ # if the action is stop then we end the episode
227
+ if isinstance(action, StopAction):
228
+ # if the action is stop and task is impossible then success
229
+ # otherwise we should not have stopped
230
+ observation = None
231
+ success = example.impossible
232
+ # action is external tool then it is str
233
+ elif isinstance(action, str):
234
+ observation = environment.step()
235
+ observation["target"] = example.target
236
+ observation["message"] = action
237
+ # action is environment action
238
+ else:
239
+ observation = environment.step(action)
240
+ observation["target"] = example.target
241
+ observation["message"] = self.convert_observation_to_message(
242
+ observation, model=model
243
+ )
244
+ # check if the episode is done
245
+ success = self.check_done(observation["inventory"], example.target)
246
+
247
+ # update model with success or failure
248
+ # observation is the next state after the action (s1)
249
+ # history is the dialogue history
250
+ # -- the last message contains the action taken (a0)
251
+ # -- the second to last message is the observation (s0)
252
+ # success is whether the episode is sucessful (r)
253
+ model.update(
254
+ observation=observation, history=history, success=success, action=action
255
+ )
256
+
257
+ # exit if success
258
+ if success or isinstance(action, StopAction):
259
+ break
260
+
250
261
  # save results and reset
251
262
  return {
252
263
  "success": success,
@@ -263,6 +274,13 @@ class Evaluator:
263
274
  examples: list[PlancraftExample],
264
275
  model,
265
276
  ) -> list:
277
+ """
278
+ Similar to eval_example, but processes multiple examples at once.
279
+
280
+ Tracks which environments are still active until they've either succeeded,
281
+ reached max steps, or invoked StopAction.
282
+ """
283
+
266
284
  # Initialize environments and histories
267
285
  environments = [
268
286
  PlancraftEnvironment(
@@ -271,126 +289,139 @@ class Evaluator:
271
289
  )
272
290
  for i in range(len(examples))
273
291
  ]
274
-
275
292
  histories = [self.create_history() for _ in range(len(examples))]
276
293
 
277
294
  # Track which environments are still active
278
295
  active_mask = [True for _ in range(len(examples))]
279
296
  results = [None for _ in range(len(examples))]
280
- steps_taken = [0 for _ in range(len(examples))]
281
- actions = [None for _ in range(len(examples))]
282
-
283
- while any(active_mask) and all(steps < self.max_steps for steps in steps_taken):
284
- # Get observations for all active environments
285
- observations = []
286
- active_indices = []
287
- active_histories = []
288
-
289
- for i, (env, action, active) in enumerate(
290
- zip(environments, actions, active_mask)
291
- ):
292
- if not active:
293
- continue
294
-
295
- if isinstance(action, StopAction):
296
- # Handle stop action
297
- active_mask[i] = False
298
- results[i] = {
299
- "success": examples[i].impossible,
300
- "recipe_type": examples[i].recipe_type,
301
- "complexity": examples[i].complexity_split,
302
- "number_of_steps": steps_taken[i],
303
- "model_trace": histories[i].trace(),
304
- "example_id": examples[i].id,
305
- "images": histories[i].images,
306
- }
307
- continue
297
+ observations = []
298
+
299
+ # Initialize observations (s0) and user messages from environment
300
+ for i in range(len(examples)):
301
+ obs = environments[i].step()
302
+ obs["target"] = examples[i].target
303
+ obs["message"] = self.convert_observation_to_message(obs, model=model)
304
+ observations.append(obs)
305
+
306
+ # Process until all done or max steps reached
307
+ while any(active_mask) and all(
308
+ history.num_steps < self.max_steps for history in histories
309
+ ):
310
+ # Gather active environments
311
+ active_indices = [
312
+ i
313
+ for i, active in enumerate(active_mask)
314
+ if active and histories[i].num_steps < self.max_steps
315
+ ]
316
+ if not active_indices:
317
+ break
308
318
 
309
- # Get observation
310
- if isinstance(action, str):
311
- obs = env.step()
312
- obs["target"] = examples[i].target
313
- obs["message"] = action
314
- else:
315
- obs = env.step(action)
316
- obs["target"] = examples[i].target
317
- obs["message"] = self.convert_observation_to_message(
318
- obs, model=model
319
+ # For each active environment, add new obs to history for next iteration
320
+ for env_idx in active_indices:
321
+ if active_mask[env_idx]:
322
+ histories[env_idx].add_observation_to_history(observations[env_idx])
323
+ histories[env_idx].add_message_to_history(
324
+ content=observations[env_idx]["message"], role="user"
319
325
  )
320
326
 
321
- # Check if done
322
- if self.check_done(obs["inventory"], examples[i].target):
323
- active_mask[i] = False
324
- results[i] = {
325
- "success": True,
326
- "recipe_type": examples[i].recipe_type,
327
- "complexity": examples[i].complexity_split,
328
- "number_of_steps": steps_taken[i],
329
- "model_trace": histories[i].trace(),
330
- "example_id": examples[i].id,
331
- "images": histories[i].images,
332
- }
333
- continue
334
-
335
- # Add to batch lists
336
- active_indices.append(i)
337
- observations.append(obs)
338
- active_histories.append(histories[i])
339
-
340
- # Update history
341
- histories[i].add_observation_to_history(obs)
342
- histories[i].add_message_to_history(content=obs["message"], role="user")
343
- steps_taken[i] += 1
344
-
345
- if not observations:
346
- break
327
+ batch_observations = [observations[i] for i in active_indices]
328
+ batch_histories = [histories[i] for i in active_indices]
347
329
 
348
- # Batch predict actions for active environments
330
+ # Predict next actions in batch
349
331
  raw_actions = model.batch_step(
350
- observations, dialogue_histories=active_histories
332
+ batch_observations, dialogue_histories=batch_histories
351
333
  )
352
334
 
353
- # Process actions for each active environment
354
- for batch_idx, (idx, raw_action) in enumerate(
355
- zip(active_indices, raw_actions)
356
- ):
357
- # if the model returns a PlancraftModelOutput, extract the action
335
+ # Process each raw action and update environment/history
336
+ successes = []
337
+ actions = []
338
+ for env_idx, raw_action in zip(active_indices, raw_actions):
339
+ # Add model's message to history
358
340
  if isinstance(raw_action, PlancraftModelOutput):
359
- # add message to history
360
- histories[idx].add_message_to_history(
341
+ histories[env_idx].add_message_to_history(
361
342
  content=raw_action.action,
362
343
  role="assistant",
363
344
  **(raw_action.kwargs or {}),
364
345
  )
365
- actions[idx] = self.parse_raw_model_response(
366
- raw_action.action,
367
- observation=observations[batch_idx],
368
- history=histories[idx],
369
- )
370
- # if the model returns a string, parse the raw action
346
+ raw_action = raw_action.action
371
347
  elif isinstance(raw_action, str):
372
- # add message to history
373
- histories[idx].add_message_to_history(
348
+ histories[env_idx].add_message_to_history(
374
349
  content=raw_action, role="assistant"
375
350
  )
376
- actions[idx] = self.parse_raw_model_response(
377
- raw_action,
378
- observation=observations[batch_idx],
379
- history=histories[idx],
380
- )
381
351
  else:
382
352
  raise ValueError(
383
- f"model.step() output must be a string or PlancraftModelOutput, got {type(raw_action)}"
353
+ f"model.batch_step() must return list[str] or list[PlancraftModelOutput], got {type(raw_action)}"
384
354
  )
385
355
 
386
- # Fill in results for environments that didn't finish
356
+ # Parse action
357
+ action = self.parse_raw_model_response(
358
+ raw_action,
359
+ observation=observations[env_idx],
360
+ history=histories[env_idx],
361
+ )
362
+ actions.append(action)
363
+ success = False
364
+ # If action is StopAction
365
+ if isinstance(action, StopAction):
366
+ # if the action is StopAction and the example is impossible,
367
+ # we consider that a 'success' in the sense that the model recognized it can't be done
368
+ success = examples[env_idx].impossible
369
+ observations[env_idx] = None
370
+ # If parsed action is a string, it's a message
371
+ elif isinstance(action, str):
372
+ obs = environments[env_idx].step()
373
+ obs["target"] = examples[env_idx].target
374
+ obs["message"] = action
375
+ observations[env_idx] = obs
376
+ # Otherwise it's an actual environment action
377
+ else:
378
+ obs = environments[env_idx].step(action)
379
+ obs["target"] = examples[env_idx].target
380
+ obs["message"] = self.convert_observation_to_message(
381
+ obs, model=model
382
+ )
383
+ observations[env_idx] = obs
384
+ success = self.check_done(
385
+ obs["inventory"], examples[env_idx].target
386
+ )
387
+
388
+ successes.append(success)
389
+
390
+ # If done, or action was stop, mark inactive and store result
391
+ if (
392
+ success
393
+ or isinstance(action, StopAction)
394
+ or histories[env_idx].num_steps >= self.max_steps
395
+ ):
396
+ active_mask[env_idx] = False
397
+ results[env_idx] = {
398
+ "success": success,
399
+ "recipe_type": examples[env_idx].recipe_type,
400
+ "complexity": examples[env_idx].complexity_split,
401
+ "number_of_steps": histories[env_idx].num_steps,
402
+ "model_trace": histories[env_idx].trace(),
403
+ "example_id": examples[env_idx].id,
404
+ "images": histories[env_idx].images,
405
+ }
406
+
407
+ # Update the model for this single environment
408
+ batch_observations = [observations[i] for i in active_indices]
409
+ batch_histories = [histories[i] for i in active_indices]
410
+ model.batch_update(
411
+ observations=batch_observations,
412
+ histories=batch_histories,
413
+ successes=successes,
414
+ actions=actions,
415
+ )
416
+
417
+ # Fill in results for any environment that never completed
387
418
  for i, result in enumerate(results):
388
419
  if result is None:
389
420
  results[i] = {
390
421
  "success": False,
391
422
  "recipe_type": examples[i].recipe_type,
392
423
  "complexity": examples[i].complexity_split,
393
- "number_of_steps": steps_taken[i],
424
+ "number_of_steps": histories[i].num_steps,
394
425
  "model_trace": histories[i].trace(),
395
426
  "example_id": examples[i].id,
396
427
  "images": histories[i].images,
plancraft/models/act.py CHANGED
@@ -72,3 +72,6 @@ class ActModel(PlancraftBaseModel):
72
72
  dialogue_history.tokens_used += action_token_used
73
73
  # return raw action message
74
74
  return action_messages[0].split("\n")[0].strip()
75
+
76
+ def update(self, **kwargs):
77
+ pass
plancraft/models/base.py CHANGED
@@ -33,3 +33,10 @@ class PlancraftBaseModel(abc.ABC):
33
33
  Reset the model state - ready for a new episode
34
34
  """
35
35
  raise NotImplementedError()
36
+
37
+ @abc.abstractmethod
38
+ def update(self, **kwargs) -> None:
39
+ """
40
+ Update the model state based on the dialogue history
41
+ """
42
+ raise NotImplementedError()
plancraft/models/dummy.py CHANGED
@@ -45,3 +45,9 @@ class DummyModel(PlancraftBaseModel):
45
45
  self, observations: list[dict], **kwargs
46
46
  ) -> list[PlancraftModelOutput]:
47
47
  return [self.step(observation) for observation in observations]
48
+
49
+ def update(self, **kwargs):
50
+ pass
51
+
52
+ def batch_update(self, **kwargs):
53
+ pass
@@ -47,3 +47,9 @@ class OracleModel(PlancraftBaseModel):
47
47
  action = self.step(observation)
48
48
  actions.append(action)
49
49
  return actions
50
+
51
+ def update(self, **kwargs):
52
+ pass
53
+
54
+ def batch_update(self, **kwargs):
55
+ pass
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: plancraft
3
- Version: 0.3.27
3
+ Version: 0.3.29
4
4
  Summary: Plancraft: an evaluation dataset for planning with LLM agents
5
5
  License: MIT License
6
6
 
@@ -1,6 +1,6 @@
1
1
  plancraft/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  plancraft/config.py,sha256=ShsFRlJ7plsl3ToD9fiO_4LDQuXdbjNV6Xp6o3Yk2Yg,4315
3
- plancraft/evaluator.py,sha256=dyszVJtTc_PThVEeGmp6YMkmEn4gaXQW52eWaKO2FQ8,17210
3
+ plancraft/evaluator.py,sha256=v73k0O8mTUj87jC1ODL9w59IzBOoOJUfmYaB2x1s73U,18850
4
4
  plancraft/generate_dataset.py,sha256=DlrU-PmvWqSNJD1g1-8Lpb8n3N-Ogw3rje1nrRzjGKs,2382
5
5
  plancraft/utils.py,sha256=VhnxMihh6pRhNjQTK5HDc0FYWmF9_EcQyRP_a7fbIZA,7156
6
6
  plancraft/data/test.json,sha256=4jWfYMAVuZCFmGB4iZJAjlh9_8jXECdaGp8xn7_tAM4,1317131
@@ -1912,15 +1912,15 @@ plancraft/environment/tags/wooden_stairs.json,sha256=GCr2_5UGPMYZECqQ_5NYSvbwuwt
1912
1912
  plancraft/environment/tags/wooden_trapdoors.json,sha256=DbjfwoHJL8VuYWV61A1uDqW7LJsGlOP4eoxcGIQVYr4,303
1913
1913
  plancraft/environment/tags/wool.json,sha256=Z59l4mdPztVZBFaglJ4mV9H2OnyCVzhqQRi2dduak78,496
1914
1914
  plancraft/models/__init__.py,sha256=TBrarn93qt4IFJRNqtzOfaA8jGMPCgD7DFs-M84ipmk,510
1915
- plancraft/models/act.py,sha256=6Xb8rylg3OngOraVFgduH_hQR62VcoyTeFntN4q3hsQ,2691
1916
- plancraft/models/base.py,sha256=S8EdkqWpn8nE1WcrqDoA4Hx4p52qEttGxnqjIPWvl3Q,852
1915
+ plancraft/models/act.py,sha256=_OZo9a_6R0wajdR7axZarjI3IJP7glFrWeDIrbcHDmw,2737
1916
+ plancraft/models/base.py,sha256=Krm6MdOjU-qlps1WSX7pxdnqXLiyI3qsI9Na7Xk8r1c,1038
1917
1917
  plancraft/models/bbox_model.py,sha256=3b1IEspoHiVUR6GOWjEbp4YoxRhGkzKt-eOiwaN8NXo,17091
1918
- plancraft/models/dummy.py,sha256=_NUTviv5ye6KGzODRt0Zykk8shsek0QBqWCeZW3ldSQ,1495
1918
+ plancraft/models/dummy.py,sha256=UWbW3bjrQr_0UYYrNf_D0jWpUq6e50vAp21F0zi8iFM,1593
1919
1919
  plancraft/models/generators.py,sha256=F76_iPiqxUjDIrQwF58tzM0bLM91OkZJ0sBqBuki5wY,13939
1920
- plancraft/models/oracle.py,sha256=f-0KWlBuHy6wcxmDsxM3MQ_QwfBstzfbA26mlk1MgLA,1657
1920
+ plancraft/models/oracle.py,sha256=jmt_kBBNXt0VWUX7q6OHkJoRZWItCMy4qGH5qbLSc1c,1755
1921
1921
  plancraft/models/utils.py,sha256=E-sZohvolWgGbpHQKgAgkgIfUJoVnT5pMt6JP8xLHKg,4034
1922
1922
  plancraft/train/dataset.py,sha256=oFqEd4LG9oEQ-71teh0Wf7-jJbtybT2ZibfM2bBdBkM,5474
1923
- plancraft-0.3.27.dist-info/METADATA,sha256=fii20vfc62_UjIquU8OXEWtxRXfaaaA2lRo7EbFaQok,11148
1924
- plancraft-0.3.27.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
1925
- plancraft-0.3.27.dist-info/licenses/LICENSE,sha256=YGR8ehDB4t-T-lOQKMfKNR-2zsOU7E3E5NA8t25HKE0,1070
1926
- plancraft-0.3.27.dist-info/RECORD,,
1923
+ plancraft-0.3.29.dist-info/METADATA,sha256=qLWNDUZpsYGVEGNvHwRwxf912NjcVihvN_5oTvyMG5c,11148
1924
+ plancraft-0.3.29.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
1925
+ plancraft-0.3.29.dist-info/licenses/LICENSE,sha256=YGR8ehDB4t-T-lOQKMfKNR-2zsOU7E3E5NA8t25HKE0,1070
1926
+ plancraft-0.3.29.dist-info/RECORD,,