plancraft 0.3.22__py3-none-any.whl → 0.3.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plancraft/evaluator.py +68 -56
- plancraft/models/base.py +12 -1
- plancraft/models/dummy.py +6 -4
- plancraft/utils.py +77 -21
- {plancraft-0.3.22.dist-info → plancraft-0.3.24.dist-info}/METADATA +1 -1
- {plancraft-0.3.22.dist-info → plancraft-0.3.24.dist-info}/RECORD +8 -8
- {plancraft-0.3.22.dist-info → plancraft-0.3.24.dist-info}/WHEEL +0 -0
- {plancraft-0.3.22.dist-info → plancraft-0.3.24.dist-info}/licenses/LICENSE +0 -0
plancraft/evaluator.py
CHANGED
@@ -19,8 +19,8 @@ from plancraft.environment.env import (
|
|
19
19
|
get_objective_str,
|
20
20
|
target_and_inventory_to_text_obs,
|
21
21
|
)
|
22
|
-
from plancraft.models.base import PlancraftBaseModel
|
23
|
-
from plancraft.utils import History
|
22
|
+
from plancraft.models.base import PlancraftBaseModel, PlancraftModelOutput
|
23
|
+
from plancraft.utils import HistoryBase, History, HistoryConfig
|
24
24
|
|
25
25
|
|
26
26
|
class Evaluator:
|
@@ -41,40 +41,39 @@ class Evaluator:
|
|
41
41
|
actions: list[ActionHandlerBase] = [MoveActionHandler(), SmeltActionHandler()],
|
42
42
|
output_dir: str = "output",
|
43
43
|
split: str = "val.small",
|
44
|
-
resolution: str = "high",
|
45
44
|
max_steps: int = 30,
|
46
45
|
resume: bool = False,
|
46
|
+
use_fasterrcnn: bool = False,
|
47
47
|
use_multimodal_content_format: bool = False,
|
48
48
|
use_images: bool = False,
|
49
49
|
use_text_inventory: bool = False,
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
prompt_images: list[str] = [],
|
54
|
-
few_shot: bool = True,
|
50
|
+
resolution: str = "high",
|
51
|
+
history_config: Optional[HistoryConfig] = None,
|
52
|
+
history_class: type[HistoryBase] = History,
|
55
53
|
):
|
56
54
|
self.run_name = run_name
|
55
|
+
self.actions = actions
|
56
|
+
self.output_dir = f"{output_dir}/{run_name}/{split}"
|
57
|
+
self.max_steps = max_steps
|
58
|
+
self.resume = resume
|
59
|
+
self.use_fasterrcnn = use_fasterrcnn
|
60
|
+
self.generation_number = 0
|
57
61
|
self.use_multimodal_content_format = use_multimodal_content_format
|
58
62
|
self.use_images = use_images
|
59
63
|
self.use_text_inventory = use_text_inventory
|
60
|
-
self.use_fasterrcnn = use_fasterrcnn
|
61
|
-
self.max_steps = max_steps
|
62
|
-
self.resume = resume
|
63
64
|
self.resolution = resolution
|
64
65
|
|
65
|
-
# history
|
66
|
-
self.
|
67
|
-
self.
|
68
|
-
self.prompt_images = prompt_images
|
69
|
-
self.few_shot = few_shot
|
66
|
+
# Set up history configuration
|
67
|
+
self.history_config = history_config or HistoryConfig()
|
68
|
+
self.history_class = history_class
|
70
69
|
|
71
|
-
|
72
|
-
self.generation_number = 0
|
73
|
-
self.actions = actions
|
74
|
-
|
75
|
-
# load all examples
|
70
|
+
# load examples
|
76
71
|
self.examples: list[PlancraftExample] = self.load_dataset(split)
|
77
72
|
|
73
|
+
def create_history(self) -> HistoryBase:
|
74
|
+
"""Create a new History instance with current configuration"""
|
75
|
+
return self.history_class(actions=self.actions, config=self.history_config)
|
76
|
+
|
78
77
|
def save_results_dict(self, example: PlancraftExample, results_dict: dict):
|
79
78
|
output_dir = f"{self.output_dir}/{self.generation_number}"
|
80
79
|
os.makedirs(output_dir, exist_ok=True)
|
@@ -187,17 +186,7 @@ class Evaluator:
|
|
187
186
|
)
|
188
187
|
|
189
188
|
# initialise history/dialogue tracking
|
190
|
-
history =
|
191
|
-
actions=self.actions,
|
192
|
-
use_multimodal_content_format=self.use_multimodal_content_format,
|
193
|
-
use_images=self.use_images,
|
194
|
-
use_text_inventory=self.use_text_inventory,
|
195
|
-
resolution=self.resolution,
|
196
|
-
few_shot=self.few_shot,
|
197
|
-
system_prompt=deepcopy(self.system_prompt),
|
198
|
-
prompt_examples=deepcopy(self.prompt_examples),
|
199
|
-
prompt_images=deepcopy(self.prompt_images),
|
200
|
-
)
|
189
|
+
history = self.create_history()
|
201
190
|
|
202
191
|
success = False
|
203
192
|
action = None
|
@@ -235,8 +224,24 @@ class Evaluator:
|
|
235
224
|
history.add_message_to_history(content=observation["message"], role="user")
|
236
225
|
# predict next action
|
237
226
|
raw_action = model.step(observation, dialogue_history=history)
|
238
|
-
|
239
|
-
|
227
|
+
|
228
|
+
# if the model returns a PlancraftModelOutput, extract the action
|
229
|
+
if isinstance(raw_action, PlancraftModelOutput):
|
230
|
+
# add message to history
|
231
|
+
history.add_message_to_history(
|
232
|
+
content=raw_action.action,
|
233
|
+
role="assistant",
|
234
|
+
**(raw_action.kwargs or {}),
|
235
|
+
)
|
236
|
+
raw_action = raw_action.action
|
237
|
+
elif isinstance(raw_action, str):
|
238
|
+
# add message to history
|
239
|
+
history.add_message_to_history(content=raw_action, role="assistant")
|
240
|
+
else:
|
241
|
+
raise ValueError(
|
242
|
+
f"model.step() output must be a string or PlancraftModelOutput, got {type(raw_action)}"
|
243
|
+
)
|
244
|
+
|
240
245
|
# parse the raw action
|
241
246
|
action = self.parse_raw_model_response(
|
242
247
|
raw_action, observation=observation, history=history
|
@@ -267,20 +272,7 @@ class Evaluator:
|
|
267
272
|
for i in range(len(examples))
|
268
273
|
]
|
269
274
|
|
270
|
-
histories = [
|
271
|
-
History(
|
272
|
-
actions=self.actions,
|
273
|
-
use_multimodal_content_format=self.use_multimodal_content_format,
|
274
|
-
use_images=self.use_images,
|
275
|
-
use_text_inventory=self.use_text_inventory,
|
276
|
-
resolution=self.resolution,
|
277
|
-
few_shot=self.few_shot,
|
278
|
-
system_prompt=deepcopy(self.system_prompt),
|
279
|
-
prompt_examples=deepcopy(self.prompt_examples),
|
280
|
-
prompt_images=deepcopy(self.prompt_images),
|
281
|
-
)
|
282
|
-
for _ in range(len(examples))
|
283
|
-
]
|
275
|
+
histories = [self.create_history() for _ in range(len(examples))]
|
284
276
|
|
285
277
|
# Track which environments are still active
|
286
278
|
active_mask = [True for _ in range(len(examples))]
|
@@ -362,14 +354,34 @@ class Evaluator:
|
|
362
354
|
for batch_idx, (idx, raw_action) in enumerate(
|
363
355
|
zip(active_indices, raw_actions)
|
364
356
|
):
|
365
|
-
|
366
|
-
|
367
|
-
|
368
|
-
|
369
|
-
|
370
|
-
|
371
|
-
|
372
|
-
|
357
|
+
# if the model returns a PlancraftModelOutput, extract the action
|
358
|
+
if isinstance(raw_action, PlancraftModelOutput):
|
359
|
+
# add message to history
|
360
|
+
histories[idx].add_message_to_history(
|
361
|
+
content=raw_action.action,
|
362
|
+
role="assistant",
|
363
|
+
**(raw_action.kwargs or {}),
|
364
|
+
)
|
365
|
+
actions[idx] = self.parse_raw_model_response(
|
366
|
+
raw_action.action,
|
367
|
+
observation=observations[batch_idx],
|
368
|
+
history=histories[idx],
|
369
|
+
)
|
370
|
+
# if the model returns a string, parse the raw action
|
371
|
+
elif isinstance(raw_action, str):
|
372
|
+
# add message to history
|
373
|
+
histories[idx].add_message_to_history(
|
374
|
+
content=raw_action, role="assistant"
|
375
|
+
)
|
376
|
+
actions[idx] = self.parse_raw_model_response(
|
377
|
+
raw_action,
|
378
|
+
observation=observations[batch_idx],
|
379
|
+
history=histories[idx],
|
380
|
+
)
|
381
|
+
else:
|
382
|
+
raise ValueError(
|
383
|
+
f"model.step() output must be a string or PlancraftModelOutput, got {type(raw_action)}"
|
384
|
+
)
|
373
385
|
|
374
386
|
# Fill in results for environments that didn't finish
|
375
387
|
for i, result in enumerate(results):
|
plancraft/models/base.py
CHANGED
@@ -1,15 +1,26 @@
|
|
1
1
|
import abc
|
2
2
|
|
3
|
+
from dataclasses import dataclass
|
4
|
+
from typing import Optional
|
5
|
+
|
3
6
|
from plancraft.utils import History
|
4
7
|
|
5
8
|
|
9
|
+
@dataclass
|
10
|
+
class PlancraftModelOutput:
|
11
|
+
action: str
|
12
|
+
kwargs: Optional[dict] = None
|
13
|
+
|
14
|
+
|
6
15
|
class PlancraftBaseModel(abc.ABC):
|
7
16
|
"""
|
8
17
|
Model class must implement the following methods to work with evaluator
|
9
18
|
"""
|
10
19
|
|
11
20
|
@abc.abstractmethod
|
12
|
-
def step(
|
21
|
+
def step(
|
22
|
+
self, observation: dict, dialogue_history: History
|
23
|
+
) -> PlancraftModelOutput | str:
|
13
24
|
"""
|
14
25
|
Model should output an action in text based on the types available
|
15
26
|
We also pass history to the model to allow for chat models to track the dialogue
|
plancraft/models/dummy.py
CHANGED
@@ -3,7 +3,7 @@ import random
|
|
3
3
|
from plancraft.environment.actions import (
|
4
4
|
MoveAction,
|
5
5
|
)
|
6
|
-
from plancraft.models.base import PlancraftBaseModel
|
6
|
+
from plancraft.models.base import PlancraftBaseModel, PlancraftModelOutput
|
7
7
|
|
8
8
|
|
9
9
|
class DummyModel(PlancraftBaseModel):
|
@@ -38,8 +38,10 @@ class DummyModel(PlancraftBaseModel):
|
|
38
38
|
slot_from=random_slot_from, slot_to=random_slot_to, quantity=1
|
39
39
|
)
|
40
40
|
|
41
|
-
def step(self, observation: dict, **kwargs) ->
|
42
|
-
return str(self.random_select(observation))
|
41
|
+
def step(self, observation: dict, **kwargs) -> PlancraftModelOutput:
|
42
|
+
return PlancraftModelOutput(action=str(self.random_select(observation)))
|
43
43
|
|
44
|
-
def batch_step(
|
44
|
+
def batch_step(
|
45
|
+
self, observations: list[dict], **kwargs
|
46
|
+
) -> list[PlancraftModelOutput]:
|
45
47
|
return [self.step(observation) for observation in observations]
|
plancraft/utils.py
CHANGED
@@ -1,8 +1,9 @@
|
|
1
1
|
import glob
|
2
2
|
import pathlib
|
3
|
-
from collections import Counter
|
4
3
|
from copy import copy
|
5
4
|
from typing import Optional
|
5
|
+
import abc
|
6
|
+
from dataclasses import dataclass, field
|
6
7
|
|
7
8
|
import torch
|
8
9
|
from loguru import logger
|
@@ -15,7 +16,56 @@ from plancraft.environment.prompts import (
|
|
15
16
|
)
|
16
17
|
|
17
18
|
|
18
|
-
|
19
|
+
@dataclass
|
20
|
+
class HistoryConfig:
|
21
|
+
"""Configuration for History instances"""
|
22
|
+
|
23
|
+
few_shot: bool = True
|
24
|
+
system_prompt: Optional[dict] = None
|
25
|
+
prompt_examples: list[dict] = field(default_factory=list)
|
26
|
+
prompt_images: list[str] = field(default_factory=list)
|
27
|
+
|
28
|
+
|
29
|
+
class HistoryBase(abc.ABC):
|
30
|
+
"""Abstract base class defining the interface required by the Evaluator"""
|
31
|
+
|
32
|
+
@property
|
33
|
+
@abc.abstractmethod
|
34
|
+
def num_steps(self) -> int:
|
35
|
+
"""Return the number of interaction steps taken"""
|
36
|
+
pass
|
37
|
+
|
38
|
+
@abc.abstractmethod
|
39
|
+
def add_message_to_history(
|
40
|
+
self, content: str | dict, role: str = "user", **kwargs
|
41
|
+
) -> None:
|
42
|
+
"""Add a message to the dialogue history"""
|
43
|
+
pass
|
44
|
+
|
45
|
+
@abc.abstractmethod
|
46
|
+
def add_observation_to_history(self, observation: dict, **kwargs) -> None:
|
47
|
+
"""Add an observation (inventory, image) to history"""
|
48
|
+
pass
|
49
|
+
|
50
|
+
@abc.abstractmethod
|
51
|
+
def trace(self) -> dict:
|
52
|
+
"""Return a traceable history of the interaction"""
|
53
|
+
pass
|
54
|
+
|
55
|
+
@property
|
56
|
+
@abc.abstractmethod
|
57
|
+
def images(self) -> list:
|
58
|
+
"""Return list of images"""
|
59
|
+
pass
|
60
|
+
|
61
|
+
@images.setter
|
62
|
+
@abc.abstractmethod
|
63
|
+
def images(self, value: list) -> None:
|
64
|
+
"""Set list of images"""
|
65
|
+
pass
|
66
|
+
|
67
|
+
|
68
|
+
class History(HistoryBase):
|
19
69
|
"""
|
20
70
|
History class to keep track of dialogue, actions, inventory and images
|
21
71
|
Args:
|
@@ -27,42 +77,40 @@ class History:
|
|
27
77
|
def __init__(
|
28
78
|
self,
|
29
79
|
actions: list[ActionHandlerBase] = [],
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
system_prompt: Optional[dict] = None,
|
36
|
-
prompt_examples: list[dict] = [],
|
37
|
-
prompt_images: list[str] = [],
|
80
|
+
config: HistoryConfig = HistoryConfig(),
|
81
|
+
resolution: str = "high",
|
82
|
+
use_multimodal_content_format: bool = False,
|
83
|
+
use_images: bool = False,
|
84
|
+
use_text_inventory: bool = True,
|
38
85
|
):
|
39
86
|
self.action_handlers = actions
|
40
87
|
self.use_multimodal_content_format = use_multimodal_content_format
|
41
|
-
|
88
|
+
|
42
89
|
self.use_images = use_images
|
43
90
|
self.use_text_inventory = use_text_inventory
|
44
|
-
self.resolution = resolution
|
91
|
+
self.resolution = resolution
|
45
92
|
|
46
93
|
self.inventory_history = []
|
47
94
|
self.tokens_used = 0
|
48
95
|
|
49
96
|
# use system prompt if provided
|
50
|
-
if system_prompt:
|
51
|
-
self.system_prompt_dialogue = system_prompt
|
97
|
+
if config.system_prompt:
|
98
|
+
self.system_prompt_dialogue = config.system_prompt
|
52
99
|
else:
|
53
100
|
# generate system prompt
|
54
101
|
self.system_prompt_dialogue = get_system_prompt(
|
55
102
|
handlers=self.action_handlers,
|
56
103
|
use_multimodal_content_format=self.use_multimodal_content_format,
|
57
104
|
)
|
105
|
+
self.few_shot = config.few_shot
|
58
106
|
|
59
107
|
# set up dialogue history with few-shot prompt
|
60
|
-
self.prompt_examples = prompt_examples
|
61
|
-
self.prompt_images = prompt_images
|
108
|
+
self.prompt_examples = config.prompt_examples
|
109
|
+
self.prompt_images = config.prompt_images
|
62
110
|
self.set_up_few_shot_prompt()
|
63
111
|
|
64
112
|
self.dialogue_history = copy(self.prompt_examples)
|
65
|
-
self.
|
113
|
+
self._images = copy(self.prompt_images)
|
66
114
|
self.initial_dialogue_length = len(self.dialogue_history)
|
67
115
|
|
68
116
|
def set_up_few_shot_prompt(self):
|
@@ -80,7 +128,7 @@ class History:
|
|
80
128
|
if self.use_images:
|
81
129
|
self.prompt_images = load_prompt_images(resolution=self.resolution)
|
82
130
|
|
83
|
-
def add_message_to_history(self, content: str | dict, role="user"):
|
131
|
+
def add_message_to_history(self, content: str | dict, role="user", **kwargs):
|
84
132
|
if isinstance(content, dict):
|
85
133
|
assert "content" in content, "content key not found in message"
|
86
134
|
content["role"] = role
|
@@ -102,9 +150,9 @@ class History:
|
|
102
150
|
self.inventory_history.append(inventory)
|
103
151
|
|
104
152
|
def add_image_to_history(self, image):
|
105
|
-
self.
|
153
|
+
self._images.append(image)
|
106
154
|
|
107
|
-
def add_observation_to_history(self, observation: dict):
|
155
|
+
def add_observation_to_history(self, observation: dict, **kwargs):
|
108
156
|
if observation is None:
|
109
157
|
return
|
110
158
|
if "inventory" in observation:
|
@@ -118,7 +166,7 @@ class History:
|
|
118
166
|
def reset(self):
|
119
167
|
# reset dialogue history to few-shot prompt
|
120
168
|
self.dialogue_history = copy(self.prompt_examples)
|
121
|
-
self.
|
169
|
+
self._images = copy(self.prompt_images)
|
122
170
|
self.initial_dialogue_length = len(self.dialogue_history)
|
123
171
|
|
124
172
|
self.inventory_history = []
|
@@ -138,6 +186,14 @@ class History:
|
|
138
186
|
def num_steps(self):
|
139
187
|
return (len(self.dialogue_history) - self.initial_dialogue_length) // 2
|
140
188
|
|
189
|
+
@property
|
190
|
+
def images(self) -> list:
|
191
|
+
return self._images
|
192
|
+
|
193
|
+
@images.setter
|
194
|
+
def images(self, value: list) -> None:
|
195
|
+
self._images = value
|
196
|
+
|
141
197
|
|
142
198
|
def get_downloaded_models() -> dict:
|
143
199
|
"""
|
@@ -1,8 +1,8 @@
|
|
1
1
|
plancraft/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
plancraft/config.py,sha256=ShsFRlJ7plsl3ToD9fiO_4LDQuXdbjNV6Xp6o3Yk2Yg,4315
|
3
|
-
plancraft/evaluator.py,sha256=
|
3
|
+
plancraft/evaluator.py,sha256=dyszVJtTc_PThVEeGmp6YMkmEn4gaXQW52eWaKO2FQ8,17210
|
4
4
|
plancraft/generate_dataset.py,sha256=DlrU-PmvWqSNJD1g1-8Lpb8n3N-Ogw3rje1nrRzjGKs,2382
|
5
|
-
plancraft/utils.py,sha256=
|
5
|
+
plancraft/utils.py,sha256=VhnxMihh6pRhNjQTK5HDc0FYWmF9_EcQyRP_a7fbIZA,7156
|
6
6
|
plancraft/data/test.json,sha256=4jWfYMAVuZCFmGB4iZJAjlh9_8jXECdaGp8xn7_tAM4,1317131
|
7
7
|
plancraft/data/test.small.easy.json,sha256=5NZEJ2PqIgmHQecJOIVQyM1D6GFKyJq7GVmgRudaqQk,189304
|
8
8
|
plancraft/data/test.small.json,sha256=eULAG1rdolRMXPrecV-7YoDIheKGyIT5MVpWdISV0wg,270089
|
@@ -1913,14 +1913,14 @@ plancraft/environment/tags/wooden_trapdoors.json,sha256=DbjfwoHJL8VuYWV61A1uDqW7
|
|
1913
1913
|
plancraft/environment/tags/wool.json,sha256=Z59l4mdPztVZBFaglJ4mV9H2OnyCVzhqQRi2dduak78,496
|
1914
1914
|
plancraft/models/__init__.py,sha256=TBrarn93qt4IFJRNqtzOfaA8jGMPCgD7DFs-M84ipmk,510
|
1915
1915
|
plancraft/models/act.py,sha256=6Xb8rylg3OngOraVFgduH_hQR62VcoyTeFntN4q3hsQ,2691
|
1916
|
-
plancraft/models/base.py,sha256=
|
1916
|
+
plancraft/models/base.py,sha256=S8EdkqWpn8nE1WcrqDoA4Hx4p52qEttGxnqjIPWvl3Q,852
|
1917
1917
|
plancraft/models/bbox_model.py,sha256=3b1IEspoHiVUR6GOWjEbp4YoxRhGkzKt-eOiwaN8NXo,17091
|
1918
|
-
plancraft/models/dummy.py,sha256=
|
1918
|
+
plancraft/models/dummy.py,sha256=_NUTviv5ye6KGzODRt0Zykk8shsek0QBqWCeZW3ldSQ,1495
|
1919
1919
|
plancraft/models/generators.py,sha256=F76_iPiqxUjDIrQwF58tzM0bLM91OkZJ0sBqBuki5wY,13939
|
1920
1920
|
plancraft/models/oracle.py,sha256=f-0KWlBuHy6wcxmDsxM3MQ_QwfBstzfbA26mlk1MgLA,1657
|
1921
1921
|
plancraft/models/utils.py,sha256=E-sZohvolWgGbpHQKgAgkgIfUJoVnT5pMt6JP8xLHKg,4034
|
1922
1922
|
plancraft/train/dataset.py,sha256=oFqEd4LG9oEQ-71teh0Wf7-jJbtybT2ZibfM2bBdBkM,5474
|
1923
|
-
plancraft-0.3.
|
1924
|
-
plancraft-0.3.
|
1925
|
-
plancraft-0.3.
|
1926
|
-
plancraft-0.3.
|
1923
|
+
plancraft-0.3.24.dist-info/METADATA,sha256=yCmPq3zXC2cEu5E9c1MOT1nsrfGcI_clOMYIV4GQoR4,11148
|
1924
|
+
plancraft-0.3.24.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
1925
|
+
plancraft-0.3.24.dist-info/licenses/LICENSE,sha256=YGR8ehDB4t-T-lOQKMfKNR-2zsOU7E3E5NA8t25HKE0,1070
|
1926
|
+
plancraft-0.3.24.dist-info/RECORD,,
|
File without changes
|
File without changes
|