plancraft 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
plancraft/__init__.py ADDED
File without changes
plancraft/config.py ADDED
@@ -0,0 +1,155 @@
1
+ from typing import Literal, Optional, Union
2
+
3
+ from pydantic import BaseModel, model_validator
4
+
5
# Dataset split names accepted by configs (PlancraftConfig.split) and examples.
DatasetSplit = Literal["train", "val", "val.small", "val.small.easy", "test", "test.small"]

# The recipe table is optional at import time: when its module cannot be
# imported, fall back to an empty mapping (lookups such as RECIPES[step]
# will then raise KeyError).
try:
    from plancraft.environments.recipes import RECIPES
except ImportError:
    RECIPES = {}
13
+
14
+
15
class EnvironmentConfig(BaseModel):
    """Settings used to construct the Plancraft environment."""

    # Whether the environment is symbolic. PlancraftConfig requires this to be
    # False when use_images is enabled.
    symbolic: bool
    # Select symbolic vs. real observation / action spaces independently.
    symbolic_observation_space: bool
    symbolic_action_space: bool
    # Biome name requested for the agent's spawn.
    preferred_spawn_biome: str = "plains"
    # Rendered POV resolution as a two-element list.
    resolution: list[int] = [512, 512]
21
+
22
+
23
class PlancraftConfig(BaseModel):
    """Configuration for a Plancraft evaluation run.

    Covers the model/tokenizer to load, the prompting strategy, the
    environment to run against and which observations/actions are exposed.
    """

    model: str
    adapter: str = ""
    tokenizer: str
    num_generations: int
    mode: Literal["react", "act", "oracle", "dummy"] = "react"
    output_dir: str
    max_steps: int = 30  # max number of steps (smelt/move) to take in the environment before stopping
    quantize: Literal[False, "int4", "int8"]
    environment: EnvironmentConfig
    split: DatasetSplit = "val.small"
    max_message_window: int = 30  # max number of messages to keep in dialogue history (30 is around 8k llama3 tokens)
    hot_cache: bool = True  # whether to cache the dialogue history between steps
    resume: bool = True  # resume inference
    few_shot: bool = True  # whether to use few-shot prompt
    system_prompt: bool = True  # whether to use system prompt
    valid_actions: list[str] = ["move", "smelt", "think", "search", "impossible"]
    use_maskrcnn: bool = False  # whether to use maskrcnn for multimodal parsing

    # observations
    use_text_inventory: bool = True  # whether to include inventory in text
    use_images: bool = False  # whether to include images in multimodal content
    use_multimodal_content_format: bool = (
        False  # whether to use multimodal content format
    )

    @model_validator(mode="after")
    def validate(self):
        """Cross-field checks run after field validation.

        Raises:
            ValueError: if ``valid_actions`` contains an unknown action, or if
                ``use_images`` is set while ``environment.symbolic`` is True.
                Pydantic reports this as a ValidationError.

        Returns:
            The validated model instance.
        """
        # Explicit raises instead of ``assert``: asserts are stripped under
        # ``python -O``, which would silently disable these checks.
        allowed_actions = {"move", "smelt", "think", "search", "impossible"}
        if not set(self.valid_actions).issubset(allowed_actions):
            raise ValueError(
                "valid_actions should be subset of {'move', 'smelt', 'think', 'search', 'impossible'}"
            )

        if self.use_images and self.environment.symbolic:
            raise ValueError("Set environment.symbolic to False when using images")

        return self
63
+
64
+
65
class WandbConfig(BaseModel):
    """Weights & Biases logging settings."""

    project: str  # W&B project name
    entity: str  # W&B entity (user or team) owning the project
    mode: str  # W&B run mode — presumably "online"/"offline"/"disabled"; confirm
69
+
70
+
71
class LaunchConfig(BaseModel):
    """Settings describing how to launch a job (command, resources, environment)."""

    # What to run and under which name.
    command: str
    job_name: str

    # Resource requests.
    gpu_limit: int
    gpu_product: str
    cpu_request: int
    ram_request: str  # kept as a string, presumably with a unit suffix (e.g. "32G") — confirm

    # Scheduling / runtime options.
    interactive: bool = False
    namespace: str = "informatics"  # NOTE(review): presumably a cluster namespace — confirm
    env_vars: dict[str, dict[str, str]]
81
+
82
+
83
class EvalConfig(BaseModel):
    """Top-level configuration for an evaluation run."""

    plancraft: PlancraftConfig  # model / prompting / environment settings
    wandb: WandbConfig  # experiment logging
    launch: LaunchConfig  # job launch settings
87
+
88
+
89
class TrainingArgs(BaseModel):
    """Hyper-parameters for a fine-tuning run."""

    # General
    base_model: str = "llama3"
    trace_mode: str = "oa"
    push_to_hub: bool = False

    # LoRA / QLoRA (QLoRA uses less memory but does not work with multi-GPU training)
    qlora: bool = False
    lora_alpha: int = 16
    lora_dropout: float = 0.1
    lora_r: int = 64

    # Data / batching
    seed: int = 42
    batch_size: int = 1
    # NOTE(review): 8142 may have been intended as 8192 — confirm.
    max_seq_length: int = 8142
    max_message_window: int = 100
    only_assistant: bool = True

    # Optimisation schedule
    gradient_accumulation_steps: int = 4
    learning_rate: float = 2e-4
    max_grad_norm: float = 0.3
    warmup_ratio: float = 0.03
    num_train_epochs: int = 3
    num_workers: int = 1
114
+ num_workers: int = 1
115
+
116
+
117
class TrainConfig(BaseModel):
    """Top-level configuration for a training run."""

    training: TrainingArgs  # optimisation / LoRA hyper-parameters
    wandb: WandbConfig  # experiment logging
    launch: LaunchConfig  # job launch settings
121
+
122
+
123
class PlancraftExample(BaseModel):
    """One Plancraft task: a target item, a starting inventory and metadata."""

    target: str
    inventory: dict[str, int]
    slotted_inventory: list[dict[str, Union[str, int]]]
    num_distractors: int
    impossible: bool
    optimal_path_length: Optional[int]
    optimal_path: Optional[list[str]]
    inventory_trace: Optional[list[dict[str, int]]]
    items_used: Optional[int]
    unique_items_used: Optional[int]
    complexity: Optional[int]
    complexity_bin: int
    unseen_in_train: bool
    unseen_in_val: bool
    split: DatasetSplit
    id: str

    # Filled in by model_post_init.
    recipe_type: Optional[str] = ""

    def model_post_init(self, __context):
        """Derive ``recipe_type`` from the recipes along ``optimal_path``.

        * no optimal path -> ``"impossible"``
        * all recipes for all steps share one ``recipe_type`` -> that type
        * otherwise -> ``"mixed"``

        Raises KeyError if a step is missing from RECIPES.
        """
        if self.optimal_path is None:
            self.recipe_type = "impossible"
            return

        kinds = {
            recipe.recipe_type
            for step in self.optimal_path
            for recipe in RECIPES[step]
        }
        self.recipe_type = kinds.pop() if len(kinds) == 1 else "mixed"
File without changes
@@ -0,0 +1,218 @@
1
+ from typing import Union
2
+
3
+ from pydantic import BaseModel, field_validator, model_validator
4
+
5
+
6
def convert_to_slot_index(slot: str) -> int:
    """Convert a slot label into its integer slot index.

    Accepted labels (surrounding whitespace and letter case are ignored):

    * ``"[0]"`` -> 0
    * ``"[A1]"`` .. ``"[C3]"`` -> 1 .. 9 (crafting grid, row by row)
    * ``"[I<n>]"`` with ``n >= 1`` -> ``n + 9`` (inventory slots)

    Raises:
        ValueError: if ``slot`` is not one of the forms above.
    """
    slot = slot.strip().upper()
    grid_map = {
        "[0]": 0,
        "[A1]": 1,
        "[A2]": 2,
        "[A3]": 3,
        "[B1]": 4,
        "[B2]": 5,
        "[B3]": 6,
        "[C1]": 7,
        "[C2]": 8,
        "[C3]": 9,
    }
    if slot in grid_map:
        return grid_map[slot]

    # Everything else must be an inventory label "[I<n>]". Checking the prefix
    # stops malformed labels such as "[X5]" or "[A5]" from silently mapping
    # onto inventory slots.
    if not (slot.startswith("[I") and slot.endswith("]")):
        raise ValueError(f"invalid slot label: {slot!r}")
    inventory_number = int(slot[2:-1])  # raises ValueError if not an integer
    # n < 1 would alias crafting-grid slots (e.g. "[I0]" -> 9 == "[C3]").
    if inventory_number < 1:
        raise ValueError(f"invalid inventory slot label: {slot!r}")
    return inventory_number + 9
24
+
25
+
26
def convert_from_slot_index(slot_index: int) -> str:
    """Map an integer slot index back to its textual label.

    0 -> "[0]", 1..9 -> "[A1]".."[C3]" (crafting grid, row by row),
    10 and above -> "[I1]", "[I2]", ... (inventory). Negative indices raise
    KeyError.
    """
    crafting_labels = {
        0: "[0]",
        1: "[A1]", 2: "[A2]", 3: "[A3]",
        4: "[B1]", 5: "[B2]", 6: "[B3]",
        7: "[C1]", 8: "[C2]", 9: "[C3]",
    }
    return crafting_labels[slot_index] if slot_index < 10 else f"[I{slot_index-9}]"
43
+
44
+
45
class SymbolicMoveAction(BaseModel):
    """Moves an item from one slot to another.

    Slots may be given as integers or as labels accepted by
    ``convert_to_slot_index`` (e.g. "[0]", "[A1]", "[I1]"). After validation
    ``slot_from`` is in 0..45, ``slot_to`` in 1..45 (and differs from
    ``slot_from``), and ``quantity`` in 1..64.
    """

    slot_from: int
    slot_to: int
    quantity: int
    action_type: str = "move"

    @field_validator("action_type", mode="before")
    def fix_action_type(cls, value) -> str:
        # Any explicitly supplied action_type is overwritten: this model is always a move.
        return "move"

    @field_validator("slot_from", "slot_to", mode="before")
    def transform_str_to_int(cls, value) -> int:
        # if value is a string like [A1] or [I1], convert it to an integer
        if isinstance(value, str):
            try:
                return convert_to_slot_index(value)
            except ValueError as err:
                raise AttributeError(
                    "slot_from and slot_to must be [0] or [A1] to [C3] or [I1] to [I36]"
                ) from err
        return value

    @field_validator("quantity", mode="before")
    def transform_quantity(cls, value) -> int:
        # Accept numeric strings (e.g. parsed from model output).
        if isinstance(value, str):
            try:
                return int(value)
            except ValueError as err:
                raise AttributeError("quantity must be an integer") from err
        return value

    @model_validator(mode="after")
    def validate(self):
        """Check slot/quantity ranges.

        Raises:
            AttributeError: on equal slots or out-of-range values.

        Returns:
            The validated instance. An "after" model validator must return
            ``self``; returning None breaks ``model_validate`` and nested
            validation (e.g. inside a union field).
        """
        if self.slot_from == self.slot_to:
            raise AttributeError("slot_from and slot_to must be different")
        if self.slot_from < 0 or self.slot_from > 45:
            raise AttributeError("slot_from must be between 0 and 45")
        if self.slot_to < 1 or self.slot_to > 45:
            raise AttributeError("slot_to must be between 1 and 45")
        if self.quantity < 1 or self.quantity > 64:
            raise AttributeError("quantity must be between 1 and 64")
        return self

    def to_action_dict(self) -> dict:
        """Return the environment action dict for this move."""
        return {
            "inventory_command": [self.slot_from, self.slot_to, self.quantity],
        }
93
+
94
+
95
class SymbolicSmeltAction(BaseModel):
    """Smelts an item and moves the result into a new slot.

    Slots may be given as integers or as labels accepted by
    ``convert_to_slot_index`` (e.g. "[0]", "[A1]", "[I1]"). After validation
    ``slot_from`` is in 0..45, ``slot_to`` in 1..45 (and differs from
    ``slot_from``), and ``quantity`` in 1..64.
    """

    slot_from: int
    slot_to: int
    quantity: int
    action_type: str = "smelt"

    @field_validator("action_type", mode="before")
    def fix_action_type(cls, value) -> str:
        # Any explicitly supplied action_type is overwritten: this model is always a smelt.
        return "smelt"

    @field_validator("slot_from", "slot_to", mode="before")
    def transform_str_to_int(cls, value) -> int:
        # if value is a string like [A1] or [I1], convert it to an integer
        if isinstance(value, str):
            try:
                return convert_to_slot_index(value)
            except ValueError as err:
                raise AttributeError(
                    "slot_from and slot_to must be [0] or [A1] to [C3] or [I1] to [I36]"
                ) from err
        return value

    @field_validator("quantity", mode="before")
    def transform_quantity(cls, value) -> int:
        # Accept numeric strings (e.g. parsed from model output).
        if isinstance(value, str):
            try:
                return int(value)
            except ValueError as err:
                raise AttributeError("quantity must be an integer") from err
        return value

    @model_validator(mode="after")
    def validate(self):
        """Check slot/quantity ranges.

        Raises:
            AttributeError: on equal slots or out-of-range values.

        Returns:
            The validated instance. An "after" model validator must return
            ``self``; returning None breaks ``model_validate`` and nested
            validation (e.g. inside a union field).
        """
        if self.slot_from == self.slot_to:
            raise AttributeError("slot_from and slot_to must be different")
        if self.slot_from < 0 or self.slot_from > 45:
            raise AttributeError("slot_from must be between 0 and 45")
        if self.slot_to < 1 or self.slot_to > 45:
            raise AttributeError("slot_to must be between 1 and 45")
        if self.quantity < 1 or self.quantity > 64:
            raise AttributeError("quantity must be between 1 and 64")
        return self

    def to_action_dict(self) -> dict:
        """Return the environment action dict for this smelt."""
        return {
            "smelt": [self.slot_from, self.slot_to, self.quantity],
        }
143
+
144
+
145
class ThinkAction(BaseModel):
    """Think about the answer before answering"""

    # Free-form reasoning text; it produces no environment action.
    thought: str

    def to_action_dict(self) -> dict:
        """Return an empty action dict: thinking does not touch the environment."""
        return dict()
152
+
153
+
154
class SearchAction(BaseModel):
    """Searches for a relevant document in the wiki"""

    # Query text to look up.
    search_string: str

    def to_action_dict(self) -> dict:
        """Return the action dict carrying the search query under "search"."""
        return dict(search=self.search_string)
163
+
164
+
165
class RealActionInteraction(BaseModel):
    """Low-level mouse interaction for the non-symbolic action space."""

    mouse_direction_x: float = 0
    mouse_direction_y: float = 0
    right_click: bool = False
    left_click: bool = False

    @field_validator("mouse_direction_x", "mouse_direction_y")
    def prevent_zero(cls, v):
        # Despite its name, this clamps each camera delta into [-10, 10];
        # in-range values (including NaN, which fails both tests) pass through.
        if v > 10:
            return 10
        if v < -10:
            return -10
        return v

    def to_action_dict(self) -> dict:
        """Translate into the env action dict: camera deltas and 0/1 click flags."""
        camera_delta = [self.mouse_direction_x, self.mouse_direction_y]
        use_flag = int(self.right_click)
        attack_flag = int(self.left_click)
        return {"camera": camera_delta, "use": use_flag, "attack": attack_flag}
185
+
186
+
187
class StopAction(BaseModel):
    """Signal that the model gives up planning because the task is judged impossible.

    Also referred to as the "impossible" action.
    """

    # Optional explanation supplied with the decision to stop.
    reason: str = ""
194
+
195
+
196
class NoOp(SymbolicMoveAction):
    """A do-nothing action, represented as a move from slot 0 to slot 0."""

    def __init__(self):
        # Validation rejects slot_from == slot_to, so construct a valid move
        # first and then overwrite slot_to. This relies on attribute assignment
        # not being re-validated (no validate_assignment config is set here).
        super().__init__(slot_from=0, slot_to=1, quantity=1)
        self.slot_to = 0

    def __str__(self):
        return "NoOp"

    def __call__(self, *args, **kwargs):
        # Calling a NoOp does nothing and ignores its arguments.
        return
208
+
209
+
210
# Action type used when the action space is symbolic.
# NOTE(review): the intent was "move or smelt", but only the move action is
# currently part of this alias (SymbolicSmeltAction was left out) — confirm.
SymbolicAction = SymbolicMoveAction

# Action type used when the action space is not symbolic: mouse/camera
# interaction, with smelting still available as a symbolic command.
RealAction = RealActionInteraction | SymbolicSmeltAction
215
+
216
+
217
class PydanticSymbolicAction(BaseModel):
    """Wrapper model whose ``root`` field holds either a move or a smelt action."""

    root: SymbolicMoveAction | SymbolicSmeltAction
@@ -0,0 +1,316 @@
1
+ from typing import Sequence, Union
2
+
3
+ import numpy as np
4
+ import json
5
+
6
+ from loguru import logger
7
+
8
+ from plancraft.environments.actions import RealAction
9
+
10
+ try:
11
+ from minerl.env import _singleagent
12
+ from minerl.herobraine.env_specs.human_controls import HumanControlEnvSpec
13
+ from minerl.herobraine.hero import handlers, mc, spaces
14
+ from minerl.herobraine.hero.handler import Handler
15
+ from minerl.herobraine.hero.handlers.agent.action import Action
16
+ from minerl.herobraine.hero.handlers.agent.start import InventoryAgentStart
17
+ from minerl.herobraine.hero.handlers.translation import TranslationHandler
18
+
19
class InventoryCommandAction(Action):
    """Action handler letting an agent programmatically move items within an open container.

    The action value is a 3-tuple: (source slot, destination slot, count).
    """

    def __init__(self):
        self._command = "inventory_command"
        source_slot = spaces.Discrete(46)
        target_slot = spaces.Discrete(46)
        count = spaces.Discrete(64)
        super().__init__(self.command, spaces.Tuple((source_slot, target_slot, count)))

    def to_string(self):
        return "inventory_command"

    def xml_template(self) -> str:
        return "<InventoryCommands/>"

    def from_universal(self, x):
        # No translation from universal actions: always the all-zero command.
        return np.zeros(3, dtype=np.int32)
50
+
51
class SmeltCommandAction(Action):
    """Action handler for smelting an item from the agent's inventory.

    Smelting is treated as instantaneous.
    TODO: the smelting time could be explored as an extra planning parameter.

    The action value is a 3-tuple: (source slot, destination slot, count).
    """

    def __init__(self):
        self._command = "smelt"
        source_slot = spaces.Discrete(46)
        target_slot = spaces.Discrete(46)
        count = spaces.Discrete(64)
        super().__init__(self.command, spaces.Tuple((source_slot, target_slot, count)))

    def to_string(self):
        return "smelt"

    def xml_template(self) -> str:
        return "<SmeltCommands/>"

    def from_universal(self, x):
        # No translation from universal actions: always the all-zero command.
        return np.zeros(3, dtype=np.int32)
84
+
85
class InventoryResetAction(Action):
    """Action handler that replaces the agent's inventory with a given item list."""

    def __init__(self):
        self._command = "inventory_reset"
        super().__init__(self._command, spaces.Text([1]))

    def to_string(self) -> str:
        return "inventory_reset"

    def to_hero(self, inventory_items: list[dict]):
        # Serialised as "<command> <json list of item dicts>".
        payload = json.dumps(inventory_items)
        return f"{self._command} {payload}"

    def xml_template(self) -> str:
        return "<InventoryResetCommands/>"

    def from_universal(self, x):
        return []
101
+
102
# One minute in environment steps — presumably 20 steps per second; confirm.
MINUTE = 60 * 20
103
+
104
class CustomInventoryAgentStart(InventoryAgentStart):
    """Agent-start handler that seeds the inventory from a list of slot dicts."""

    def __init__(self, inventory: list[dict[str, Union[str, int]]]):
        # Key each item dict by its "slot"; a later duplicate slot wins.
        by_slot = {}
        for entry in inventory:
            by_slot[entry["slot"]] = entry
        super().__init__(by_slot)
107
+
108
class CraftingTableOnly(Handler):
    """Handler emitting the <CraftingTableOnly> mission flag set to true."""

    def xml_template(self) -> str:
        return "<CraftingTableOnly>true</CraftingTableOnly>"

    def to_string(self):
        return "start_with_crafting_table"
114
+
115
class InventoryObservation(TranslationHandler):
    """
    Handles GUI Workbench Observations for selected items
    """

    def __init__(self, item_list, _other="other"):
        ordered_items = sorted(item_list)
        # One scalar int32 Box per item key, bounded to [0, 2304]
        # (2304 == 36 * 64, presumably a full inventory of stacks — confirm).
        per_item_space = {
            key: spaces.Box(
                low=0, high=2304, shape=(), dtype=np.int32, normalizer_scale="log"
            )
            for key in ordered_items
        }
        super().__init__(spaces.Dict(spaces=per_item_space))
        self.num_items = len(ordered_items)
        self.items = ordered_items

    def to_string(self):
        return "inventory"

    def xml_template(self) -> str:
        return """<ObservationFromFullInventory flat="false"/>"""

    def add_to_mission_spec(self, mission_spec):
        # Intentionally contributes nothing to the mission spec.
        return None

    def from_hero(self, info):
        # The raw "inventory" entry of the info dict is passed through as-is.
        return info["inventory"]

    def from_universal(self, obs):
        raise NotImplementedError(
            "from_universal not implemented in InventoryObservation"
        )
155
+
156
class PlancraftBaseEnvSpec(HumanControlEnvSpec):
    """MineRL environment spec for Plancraft.

    The action space is either symbolic (inventory / smelt / reset commands)
    or real (keyboard + camera, plus smelt / reset commands). A symbolic
    inventory observation can be added next to the POV image.
    """

    def __init__(
        self,
        symbolic_action_space=False,
        symbolic_observation_space=False,
        max_episode_steps=2 * MINUTE,
        inventory: Sequence[dict] = (),
        preferred_spawn_biome: str = "plains",
        resolution=None,
    ):
        """Configure the spec.

        Args:
            symbolic_action_space: use inventory/smelt/reset commands instead
                of keyboard + camera controls.
            symbolic_observation_space: add an inventory observation keyed by
                the inventory's slots alongside the POV image.
            max_episode_steps: passed to the base spec; also read by
                create_server_quit_producers.
            inventory: item dicts (each with a "slot" key) used to seed the
                agent's starting inventory.
            preferred_spawn_biome: biome name given to PreferredSpawnBiome.
            resolution: POV resolution; defaults to a fresh [260, 180].
        """
        # None sentinel instead of a shared mutable list default.
        if resolution is None:
            resolution = [260, 180]

        # The create_* hooks below read these attributes, so they are set
        # before super().__init__ (which presumably invokes those hooks).
        self.inventory = inventory
        self.preferred_spawn_biome = preferred_spawn_biome
        self.symbolic_action_space = symbolic_action_space
        self.symbolic_observation_space = symbolic_observation_space

        # Build the env name, e.g. "plancraft-real-symbolic-act-symbolic-obs-v0".
        mode = "real"
        if symbolic_action_space:
            mode += "-symbolic-act"
        else:
            mode += "-real-act"
        if symbolic_observation_space:
            mode += "-symbolic-obs"

        # 1px cursor for the symbolic action space, 16px otherwise.
        cursor_size = 1 if symbolic_action_space else 16

        super().__init__(
            name=f"plancraft-{mode}-v0",
            max_episode_steps=max_episode_steps,
            resolution=resolution,
            cursor_size_range=[cursor_size, cursor_size],
        )

    def create_agent_start(self) -> list[Handler]:
        """Base agent-start handlers plus inventory, spawn biome, done-on-death and crafting-table flag."""
        base_agent_start_handlers = super().create_agent_start()
        return base_agent_start_handlers + [
            CustomInventoryAgentStart(self.inventory),
            handlers.PreferredSpawnBiome(self.preferred_spawn_biome),
            handlers.DoneOnDeath(),
            CraftingTableOnly(),
        ]

    def create_observables(self) -> list[TranslationHandler]:
        """POV image, plus an inventory observation when the observation space is symbolic."""
        if self.symbolic_observation_space:
            return [
                handlers.POVObservation(self.resolution),
                InventoryObservation([item["slot"] for item in self.inventory]),
            ]
        return [handlers.POVObservation(self.resolution)]

    def create_server_world_generators(self) -> list[Handler]:
        # TODO the original biome forced is not implemented yet. Use this for now.
        return [handlers.DefaultWorldGenerator(force_reset=True)]

    def create_server_quit_producers(self) -> list[Handler]:
        """Quit on time-up (max_episode_steps * MS_PER_STEP) or when any agent finishes."""
        return [
            handlers.ServerQuitFromTimeUp(
                (self.max_episode_steps * mc.MS_PER_STEP)
            ),
            handlers.ServerQuitWhenAnyAgentFinishes(),
        ]

    def create_server_initial_conditions(self) -> list[Handler]:
        """Frozen time of day, spawning allowed."""
        return [
            handlers.TimeInitialCondition(allow_passage_of_time=False),
            handlers.SpawningInitialCondition(allow_spawning=True),
        ]

    def create_actionables(self) -> list[TranslationHandler]:
        """
        Symbolic env can move items around in the inventory using function
        Real env can use camera/keyboard
        """
        if self.symbolic_action_space:
            return [
                InventoryCommandAction(),
                SmeltCommandAction(),
                InventoryResetAction(),
            ]
        # One keyboard action per KEYMAP entry (named by the mapped value),
        # then camera control and the smelt / reset commands.
        return [
            handlers.KeybasedCommandAction(v, v) for v in mc.KEYMAP.values()
        ] + [
            handlers.CameraAction(),
            SmeltCommandAction(),
            InventoryResetAction(),
        ]

    def is_from_folder(self, folder: str) -> bool:
        return False

    def create_agent_handlers(self) -> list[Handler]:
        return []

    def create_mission_handlers(self):
        return []

    def create_monitors(self):
        return []

    def create_rewardables(self):
        return []

    def create_server_decorators(self) -> list[Handler]:
        return []

    def determine_success_from_rewards(self, rewards: list) -> bool:
        return False

    def get_docstring(self):
        return self.__class__.__doc__
271
+
272
class RealPlancraft(_singleagent._SingleAgentEnv):
    """Single-agent MineRL environment for a Plancraft episode.

    The environment is reset once on construction. ``step`` accepts either an
    action object exposing ``to_action_dict()`` or a raw action dict.
    """

    def __init__(
        self,
        inventory: list[dict],
        preferred_spawn_biome="plains",
        symbolic_action_space=False,
        symbolic_observation_space=True,
        resolution=None,
        crop=True,
    ):
        """Build the env spec and reset the environment.

        Args:
            inventory: item dicts (each with a "slot" key) for the starting inventory.
            preferred_spawn_biome: biome name requested for spawning.
            symbolic_action_space: use symbolic inventory commands instead of keyboard/camera.
            symbolic_observation_space: include the symbolic inventory observation.
            resolution: POV resolution; defaults to a fresh [512, 512].
            crop: crop the POV image in ``step`` (only applied at 512x512).
        """
        # None sentinel instead of a shared mutable list default.
        if resolution is None:
            resolution = [512, 512]
        # NOTE: crop is only supported for resolution 512x512 (default)
        self.crop = crop
        self.resolution = resolution
        env_spec = PlancraftBaseEnvSpec(
            symbolic_action_space=symbolic_action_space,
            symbolic_observation_space=symbolic_observation_space,
            preferred_spawn_biome=preferred_spawn_biome,
            inventory=inventory,
            resolution=resolution,
        )
        super().__init__(env_spec=env_spec)
        self.reset()

    def step(self, action: RealAction | dict):
        """Apply one action and return ``(obs, reward, done, info)``.

        When cropping is enabled at 512x512, ``obs["pov"]`` is cropped.
        """
        if not isinstance(action, dict):
            action = action.to_action_dict()
        obs, rew, done, info = super().step(action)
        # Compare as tuples so a (512, 512) tuple resolution also enables the crop.
        if "pov" in obs and self.crop and tuple(self.resolution) == (512, 512):
            # NOTE(review): the crop was described as x=174, y=170, width=164,
            # height=173, but this slice takes rows 174:338 (164) and columns
            # 170:341 (171). Left unchanged because downstream consumers may
            # depend on the exact crop — confirm the intended bounds.
            obs["pov"] = obs["pov"][174 : 174 + 164, 170 : 168 + 173]
        return obs, rew, done, info

    def fast_reset(self, new_inventory: list[dict]):
        """Replace the inventory via an "inventory_reset" action instead of a full reset.

        The result of the underlying step is discarded.
        """
        super().step({"inventory_reset": new_inventory})
306
+
307
+
308
+ except ImportError:
309
+
310
class RealPlancraft:
    """Stand-in used when ``minerl`` is missing; constructing it always fails."""

    def __init__(self, *args, **kwargs):
        install_hint = (
            "The 'minerl' package is required to use RealPlancraft. "
            "Please install it using 'pip install plancraft[full]' or 'pip install minerl'."
        )
        logger.warning(install_hint)
        raise ImportError("minerl package not found")