plancraft 0.3.22__py3-none-any.whl → 0.3.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
plancraft/evaluator.py CHANGED
@@ -19,8 +19,8 @@ from plancraft.environment.env import (
19
19
  get_objective_str,
20
20
  target_and_inventory_to_text_obs,
21
21
  )
22
- from plancraft.models.base import PlancraftBaseModel
23
- from plancraft.utils import History
22
+ from plancraft.models.base import PlancraftBaseModel, PlancraftModelOutput
23
+ from plancraft.utils import HistoryBase, History, HistoryConfig
24
24
 
25
25
 
26
26
  class Evaluator:
@@ -41,40 +41,39 @@ class Evaluator:
41
41
  actions: list[ActionHandlerBase] = [MoveActionHandler(), SmeltActionHandler()],
42
42
  output_dir: str = "output",
43
43
  split: str = "val.small",
44
- resolution: str = "high",
45
44
  max_steps: int = 30,
46
45
  resume: bool = False,
46
+ use_fasterrcnn: bool = False,
47
47
  use_multimodal_content_format: bool = False,
48
48
  use_images: bool = False,
49
49
  use_text_inventory: bool = False,
50
- use_fasterrcnn: bool = False,
51
- system_prompt: Optional[dict] = None,
52
- prompt_examples: list[dict] = [],
53
- prompt_images: list[str] = [],
54
- few_shot: bool = True,
50
+ resolution: str = "high",
51
+ history_config: Optional[HistoryConfig] = None,
52
+ history_class: type[HistoryBase] = History,
55
53
  ):
56
54
  self.run_name = run_name
55
+ self.actions = actions
56
+ self.output_dir = f"{output_dir}/{run_name}/{split}"
57
+ self.max_steps = max_steps
58
+ self.resume = resume
59
+ self.use_fasterrcnn = use_fasterrcnn
60
+ self.generation_number = 0
57
61
  self.use_multimodal_content_format = use_multimodal_content_format
58
62
  self.use_images = use_images
59
63
  self.use_text_inventory = use_text_inventory
60
- self.use_fasterrcnn = use_fasterrcnn
61
- self.max_steps = max_steps
62
- self.resume = resume
63
64
  self.resolution = resolution
64
65
 
65
- # history args
66
- self.system_prompt = system_prompt
67
- self.prompt_examples = prompt_examples
68
- self.prompt_images = prompt_images
69
- self.few_shot = few_shot
66
+ # Set up history configuration
67
+ self.history_config = history_config or HistoryConfig()
68
+ self.history_class = history_class
70
69
 
71
- self.output_dir = f"{output_dir}/{run_name}/{split}"
72
- self.generation_number = 0
73
- self.actions = actions
74
-
75
- # load all examples
70
+ # load examples
76
71
  self.examples: list[PlancraftExample] = self.load_dataset(split)
77
72
 
73
+ def create_history(self) -> HistoryBase:
74
+ """Create a new History instance with current configuration"""
75
+ return self.history_class(actions=self.actions, config=self.history_config)
76
+
78
77
  def save_results_dict(self, example: PlancraftExample, results_dict: dict):
79
78
  output_dir = f"{self.output_dir}/{self.generation_number}"
80
79
  os.makedirs(output_dir, exist_ok=True)
@@ -187,17 +186,7 @@ class Evaluator:
187
186
  )
188
187
 
189
188
  # initialise history/dialogue tracking
190
- history = History(
191
- actions=self.actions,
192
- use_multimodal_content_format=self.use_multimodal_content_format,
193
- use_images=self.use_images,
194
- use_text_inventory=self.use_text_inventory,
195
- resolution=self.resolution,
196
- few_shot=self.few_shot,
197
- system_prompt=deepcopy(self.system_prompt),
198
- prompt_examples=deepcopy(self.prompt_examples),
199
- prompt_images=deepcopy(self.prompt_images),
200
- )
189
+ history = self.create_history()
201
190
 
202
191
  success = False
203
192
  action = None
@@ -235,8 +224,24 @@ class Evaluator:
235
224
  history.add_message_to_history(content=observation["message"], role="user")
236
225
  # predict next action
237
226
  raw_action = model.step(observation, dialogue_history=history)
238
- # add message to history
239
- history.add_message_to_history(content=raw_action, role="assistant")
227
+
228
+ # if the model returns a PlancraftModelOutput, extract the action
229
+ if isinstance(raw_action, PlancraftModelOutput):
230
+ # add message to history
231
+ history.add_message_to_history(
232
+ content=raw_action.action,
233
+ role="assistant",
234
+ **(raw_action.kwargs or {}),
235
+ )
236
+ raw_action = raw_action.action
237
+ elif isinstance(raw_action, str):
238
+ # add message to history
239
+ history.add_message_to_history(content=raw_action, role="assistant")
240
+ else:
241
+ raise ValueError(
242
+ f"model.step() output must be a string or PlancraftModelOutput, got {type(raw_action)}"
243
+ )
244
+
240
245
  # parse the raw action
241
246
  action = self.parse_raw_model_response(
242
247
  raw_action, observation=observation, history=history
@@ -267,20 +272,7 @@ class Evaluator:
267
272
  for i in range(len(examples))
268
273
  ]
269
274
 
270
- histories = [
271
- History(
272
- actions=self.actions,
273
- use_multimodal_content_format=self.use_multimodal_content_format,
274
- use_images=self.use_images,
275
- use_text_inventory=self.use_text_inventory,
276
- resolution=self.resolution,
277
- few_shot=self.few_shot,
278
- system_prompt=deepcopy(self.system_prompt),
279
- prompt_examples=deepcopy(self.prompt_examples),
280
- prompt_images=deepcopy(self.prompt_images),
281
- )
282
- for _ in range(len(examples))
283
- ]
275
+ histories = [self.create_history() for _ in range(len(examples))]
284
276
 
285
277
  # Track which environments are still active
286
278
  active_mask = [True for _ in range(len(examples))]
@@ -362,14 +354,34 @@ class Evaluator:
362
354
  for batch_idx, (idx, raw_action) in enumerate(
363
355
  zip(active_indices, raw_actions)
364
356
  ):
365
- histories[idx].add_message_to_history(
366
- content=raw_action, role="assistant"
367
- )
368
- actions[idx] = self.parse_raw_model_response(
369
- raw_action,
370
- observation=observations[batch_idx],
371
- history=histories[idx],
372
- )
357
+ # if the model returns a PlancraftModelOutput, extract the action
358
+ if isinstance(raw_action, PlancraftModelOutput):
359
+ # add message to history
360
+ histories[idx].add_message_to_history(
361
+ content=raw_action.action,
362
+ role="assistant",
363
+ **(raw_action.kwargs or {}),
364
+ )
365
+ actions[idx] = self.parse_raw_model_response(
366
+ raw_action.action,
367
+ observation=observations[batch_idx],
368
+ history=histories[idx],
369
+ )
370
+ # if the model returns a string, parse the raw action
371
+ elif isinstance(raw_action, str):
372
+ # add message to history
373
+ histories[idx].add_message_to_history(
374
+ content=raw_action, role="assistant"
375
+ )
376
+ actions[idx] = self.parse_raw_model_response(
377
+ raw_action,
378
+ observation=observations[batch_idx],
379
+ history=histories[idx],
380
+ )
381
+ else:
382
+ raise ValueError(
383
+ f"model.step() output must be a string or PlancraftModelOutput, got {type(raw_action)}"
384
+ )
373
385
 
374
386
  # Fill in results for environments that didn't finish
375
387
  for i, result in enumerate(results):
plancraft/models/base.py CHANGED
@@ -1,15 +1,26 @@
1
1
  import abc
2
2
 
3
+ from dataclasses import dataclass
4
+ from typing import Optional
5
+
3
6
  from plancraft.utils import History
4
7
 
5
8
 
9
+ @dataclass
10
+ class PlancraftModelOutput:
11
+ action: str
12
+ kwargs: Optional[dict] = None
13
+
14
+
6
15
  class PlancraftBaseModel(abc.ABC):
7
16
  """
8
17
  Model class must implement the following methods to work with evaluator
9
18
  """
10
19
 
11
20
  @abc.abstractmethod
12
- def step(self, observation: dict, dialogue_history: History) -> str:
21
+ def step(
22
+ self, observation: dict, dialogue_history: History
23
+ ) -> PlancraftModelOutput | str:
13
24
  """
14
25
  Model should output an action in text based on the types available
15
26
  We also pass history to the model to allow for chat models to track the dialogue
plancraft/models/dummy.py CHANGED
@@ -3,7 +3,7 @@ import random
3
3
  from plancraft.environment.actions import (
4
4
  MoveAction,
5
5
  )
6
- from plancraft.models.base import PlancraftBaseModel
6
+ from plancraft.models.base import PlancraftBaseModel, PlancraftModelOutput
7
7
 
8
8
 
9
9
  class DummyModel(PlancraftBaseModel):
@@ -38,8 +38,10 @@ class DummyModel(PlancraftBaseModel):
38
38
  slot_from=random_slot_from, slot_to=random_slot_to, quantity=1
39
39
  )
40
40
 
41
- def step(self, observation: dict, **kwargs) -> str:
42
- return str(self.random_select(observation))
41
+ def step(self, observation: dict, **kwargs) -> PlancraftModelOutput:
42
+ return PlancraftModelOutput(action=str(self.random_select(observation)))
43
43
 
44
- def batch_step(self, observations: list[dict], **kwargs) -> list:
44
+ def batch_step(
45
+ self, observations: list[dict], **kwargs
46
+ ) -> list[PlancraftModelOutput]:
45
47
  return [self.step(observation) for observation in observations]
plancraft/utils.py CHANGED
@@ -1,8 +1,9 @@
1
1
  import glob
2
2
  import pathlib
3
- from collections import Counter
4
3
  from copy import copy
5
4
  from typing import Optional
5
+ import abc
6
+ from dataclasses import dataclass, field
6
7
 
7
8
  import torch
8
9
  from loguru import logger
@@ -15,7 +16,56 @@ from plancraft.environment.prompts import (
15
16
  )
16
17
 
17
18
 
18
- class History:
19
+ @dataclass
20
+ class HistoryConfig:
21
+ """Configuration for History instances"""
22
+
23
+ few_shot: bool = True
24
+ system_prompt: Optional[dict] = None
25
+ prompt_examples: list[dict] = field(default_factory=list)
26
+ prompt_images: list[str] = field(default_factory=list)
27
+
28
+
29
+ class HistoryBase(abc.ABC):
30
+ """Abstract base class defining the interface required by the Evaluator"""
31
+
32
+ @property
33
+ @abc.abstractmethod
34
+ def num_steps(self) -> int:
35
+ """Return the number of interaction steps taken"""
36
+ pass
37
+
38
+ @abc.abstractmethod
39
+ def add_message_to_history(
40
+ self, content: str | dict, role: str = "user", **kwargs
41
+ ) -> None:
42
+ """Add a message to the dialogue history"""
43
+ pass
44
+
45
+ @abc.abstractmethod
46
+ def add_observation_to_history(self, observation: dict, **kwargs) -> None:
47
+ """Add an observation (inventory, image) to history"""
48
+ pass
49
+
50
+ @abc.abstractmethod
51
+ def trace(self) -> dict:
52
+ """Return a traceable history of the interaction"""
53
+ pass
54
+
55
+ @property
56
+ @abc.abstractmethod
57
+ def images(self) -> list:
58
+ """Return list of images"""
59
+ pass
60
+
61
+ @images.setter
62
+ @abc.abstractmethod
63
+ def images(self, value: list) -> None:
64
+ """Set list of images"""
65
+ pass
66
+
67
+
68
+ class History(HistoryBase):
19
69
  """
20
70
  History class to keep track of dialogue, actions, inventory and images
21
71
  Args:
@@ -27,42 +77,40 @@ class History:
27
77
  def __init__(
28
78
  self,
29
79
  actions: list[ActionHandlerBase] = [],
30
- use_multimodal_content_format=False,
31
- few_shot=False,
32
- use_images=False,
33
- use_text_inventory=False,
34
- resolution="high",
35
- system_prompt: Optional[dict] = None,
36
- prompt_examples: list[dict] = [],
37
- prompt_images: list[str] = [],
80
+ config: HistoryConfig = HistoryConfig(),
81
+ resolution: str = "high",
82
+ use_multimodal_content_format: bool = False,
83
+ use_images: bool = False,
84
+ use_text_inventory: bool = True,
38
85
  ):
39
86
  self.action_handlers = actions
40
87
  self.use_multimodal_content_format = use_multimodal_content_format
41
- self.few_shot = few_shot
88
+
42
89
  self.use_images = use_images
43
90
  self.use_text_inventory = use_text_inventory
44
- self.resolution = resolution # low, medium, high
91
+ self.resolution = resolution
45
92
 
46
93
  self.inventory_history = []
47
94
  self.tokens_used = 0
48
95
 
49
96
  # use system prompt if provided
50
- if system_prompt:
51
- self.system_prompt_dialogue = system_prompt
97
+ if config.system_prompt:
98
+ self.system_prompt_dialogue = config.system_prompt
52
99
  else:
53
100
  # generate system prompt
54
101
  self.system_prompt_dialogue = get_system_prompt(
55
102
  handlers=self.action_handlers,
56
103
  use_multimodal_content_format=self.use_multimodal_content_format,
57
104
  )
105
+ self.few_shot = config.few_shot
58
106
 
59
107
  # set up dialogue history with few-shot prompt
60
- self.prompt_examples = prompt_examples
61
- self.prompt_images = prompt_images
108
+ self.prompt_examples = config.prompt_examples
109
+ self.prompt_images = config.prompt_images
62
110
  self.set_up_few_shot_prompt()
63
111
 
64
112
  self.dialogue_history = copy(self.prompt_examples)
65
- self.images = copy(self.prompt_images)
113
+ self._images = copy(self.prompt_images)
66
114
  self.initial_dialogue_length = len(self.dialogue_history)
67
115
 
68
116
  def set_up_few_shot_prompt(self):
@@ -80,7 +128,7 @@ class History:
80
128
  if self.use_images:
81
129
  self.prompt_images = load_prompt_images(resolution=self.resolution)
82
130
 
83
- def add_message_to_history(self, content: str | dict, role="user"):
131
+ def add_message_to_history(self, content: str | dict, role="user", **kwargs):
84
132
  if isinstance(content, dict):
85
133
  assert "content" in content, "content key not found in message"
86
134
  content["role"] = role
@@ -102,9 +150,9 @@ class History:
102
150
  self.inventory_history.append(inventory)
103
151
 
104
152
  def add_image_to_history(self, image):
105
- self.images.append(image)
153
+ self._images.append(image)
106
154
 
107
- def add_observation_to_history(self, observation: dict):
155
+ def add_observation_to_history(self, observation: dict, **kwargs):
108
156
  if observation is None:
109
157
  return
110
158
  if "inventory" in observation:
@@ -118,7 +166,7 @@ class History:
118
166
  def reset(self):
119
167
  # reset dialogue history to few-shot prompt
120
168
  self.dialogue_history = copy(self.prompt_examples)
121
- self.images = copy(self.prompt_images)
169
+ self._images = copy(self.prompt_images)
122
170
  self.initial_dialogue_length = len(self.dialogue_history)
123
171
 
124
172
  self.inventory_history = []
@@ -138,6 +186,14 @@ class History:
138
186
  def num_steps(self):
139
187
  return (len(self.dialogue_history) - self.initial_dialogue_length) // 2
140
188
 
189
+ @property
190
+ def images(self) -> list:
191
+ return self._images
192
+
193
+ @images.setter
194
+ def images(self, value: list) -> None:
195
+ self._images = value
196
+
141
197
 
142
198
  def get_downloaded_models() -> dict:
143
199
  """
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: plancraft
3
- Version: 0.3.22
3
+ Version: 0.3.24
4
4
  Summary: Plancraft: an evaluation dataset for planning with LLM agents
5
5
  License: MIT License
6
6
 
@@ -1,8 +1,8 @@
1
1
  plancraft/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  plancraft/config.py,sha256=ShsFRlJ7plsl3ToD9fiO_4LDQuXdbjNV6Xp6o3Yk2Yg,4315
3
- plancraft/evaluator.py,sha256=R_RZN9AL_ae0rIvj7HLhYolTpCVMuhPTJfIrmyoLaX4,16326
3
+ plancraft/evaluator.py,sha256=dyszVJtTc_PThVEeGmp6YMkmEn4gaXQW52eWaKO2FQ8,17210
4
4
  plancraft/generate_dataset.py,sha256=DlrU-PmvWqSNJD1g1-8Lpb8n3N-Ogw3rje1nrRzjGKs,2382
5
- plancraft/utils.py,sha256=67UUDMSv8TqX_I0fL5-yG_vkHvTZlnhSLkktWAg5p34,5712
5
+ plancraft/utils.py,sha256=VhnxMihh6pRhNjQTK5HDc0FYWmF9_EcQyRP_a7fbIZA,7156
6
6
  plancraft/data/test.json,sha256=4jWfYMAVuZCFmGB4iZJAjlh9_8jXECdaGp8xn7_tAM4,1317131
7
7
  plancraft/data/test.small.easy.json,sha256=5NZEJ2PqIgmHQecJOIVQyM1D6GFKyJq7GVmgRudaqQk,189304
8
8
  plancraft/data/test.small.json,sha256=eULAG1rdolRMXPrecV-7YoDIheKGyIT5MVpWdISV0wg,270089
@@ -1913,14 +1913,14 @@ plancraft/environment/tags/wooden_trapdoors.json,sha256=DbjfwoHJL8VuYWV61A1uDqW7
1913
1913
  plancraft/environment/tags/wool.json,sha256=Z59l4mdPztVZBFaglJ4mV9H2OnyCVzhqQRi2dduak78,496
1914
1914
  plancraft/models/__init__.py,sha256=TBrarn93qt4IFJRNqtzOfaA8jGMPCgD7DFs-M84ipmk,510
1915
1915
  plancraft/models/act.py,sha256=6Xb8rylg3OngOraVFgduH_hQR62VcoyTeFntN4q3hsQ,2691
1916
- plancraft/models/base.py,sha256=uhG1tRmsBerJzW8qHoLyLEYpveDv0co7AAhi4mSfyO4,661
1916
+ plancraft/models/base.py,sha256=S8EdkqWpn8nE1WcrqDoA4Hx4p52qEttGxnqjIPWvl3Q,852
1917
1917
  plancraft/models/bbox_model.py,sha256=3b1IEspoHiVUR6GOWjEbp4YoxRhGkzKt-eOiwaN8NXo,17091
1918
- plancraft/models/dummy.py,sha256=3Nsnw12s_n5mWMuxUTaPCuJIzPp0vLHWKE827iKY5o0,1391
1918
+ plancraft/models/dummy.py,sha256=_NUTviv5ye6KGzODRt0Zykk8shsek0QBqWCeZW3ldSQ,1495
1919
1919
  plancraft/models/generators.py,sha256=F76_iPiqxUjDIrQwF58tzM0bLM91OkZJ0sBqBuki5wY,13939
1920
1920
  plancraft/models/oracle.py,sha256=f-0KWlBuHy6wcxmDsxM3MQ_QwfBstzfbA26mlk1MgLA,1657
1921
1921
  plancraft/models/utils.py,sha256=E-sZohvolWgGbpHQKgAgkgIfUJoVnT5pMt6JP8xLHKg,4034
1922
1922
  plancraft/train/dataset.py,sha256=oFqEd4LG9oEQ-71teh0Wf7-jJbtybT2ZibfM2bBdBkM,5474
1923
- plancraft-0.3.22.dist-info/METADATA,sha256=jTX0TZZxJRldUDDFuJ6AhuN1Bf5Jc2DuDooPVwCBkAQ,11148
1924
- plancraft-0.3.22.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
1925
- plancraft-0.3.22.dist-info/licenses/LICENSE,sha256=YGR8ehDB4t-T-lOQKMfKNR-2zsOU7E3E5NA8t25HKE0,1070
1926
- plancraft-0.3.22.dist-info/RECORD,,
1923
+ plancraft-0.3.24.dist-info/METADATA,sha256=yCmPq3zXC2cEu5E9c1MOT1nsrfGcI_clOMYIV4GQoR4,11148
1924
+ plancraft-0.3.24.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
1925
+ plancraft-0.3.24.dist-info/licenses/LICENSE,sha256=YGR8ehDB4t-T-lOQKMfKNR-2zsOU7E3E5NA8t25HKE0,1070
1926
+ plancraft-0.3.24.dist-info/RECORD,,