mettagrid 0.2.0.3__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mettagrid might be problematic.

Files changed (105)
  1. mettagrid/__init__.py +62 -0
  2. mettagrid/builder/__init__.py +7 -0
  3. mettagrid/builder/building.py +136 -0
  4. mettagrid/builder/empty_converters.py +85 -0
  5. mettagrid/builder/envs.py +221 -0
  6. mettagrid/config/__init__.py +35 -0
  7. mettagrid/config/config.py +106 -0
  8. mettagrid/config/mettagrid_c_config.py +342 -0
  9. mettagrid/config/mettagrid_config.py +223 -0
  10. mettagrid/core.py +291 -0
  11. mettagrid/demo.py +35 -0
  12. mettagrid/envs/__init__.py +13 -0
  13. mettagrid/envs/gym_env.py +127 -0
  14. mettagrid/envs/gym_wrapper.py +63 -0
  15. mettagrid/envs/mettagrid_env.py +304 -0
  16. mettagrid/envs/pettingzoo_env.py +213 -0
  17. mettagrid/envs/puffer_base.py +143 -0
  18. mettagrid/map_builder/__init__.py +10 -0
  19. mettagrid/map_builder/ascii.py +38 -0
  20. mettagrid/map_builder/map_builder.py +147 -0
  21. mettagrid/map_builder/maze.py +122 -0
  22. mettagrid/map_builder/perimeter_incontext.py +296 -0
  23. mettagrid/map_builder/random.py +95 -0
  24. mettagrid/map_builder/utils.py +94 -0
  25. mettagrid/mapgen/__init__.py +0 -0
  26. mettagrid/mapgen/load.py +35 -0
  27. mettagrid/mapgen/mapgen.py +322 -0
  28. mettagrid/mapgen/mapgen_ascii.py +51 -0
  29. mettagrid/mapgen/random/float.py +99 -0
  30. mettagrid/mapgen/random/int.py +50 -0
  31. mettagrid/mapgen/scene.py +323 -0
  32. mettagrid/mapgen/scenes/ascii.py +41 -0
  33. mettagrid/mapgen/scenes/auto.py +149 -0
  34. mettagrid/mapgen/scenes/bsp.py +417 -0
  35. mettagrid/mapgen/scenes/convchain.py +180 -0
  36. mettagrid/mapgen/scenes/copy_grid.py +36 -0
  37. mettagrid/mapgen/scenes/grid_altars.py +118 -0
  38. mettagrid/mapgen/scenes/inline_ascii.py +42 -0
  39. mettagrid/mapgen/scenes/layout.py +31 -0
  40. mettagrid/mapgen/scenes/make_connected.py +163 -0
  41. mettagrid/mapgen/scenes/maze.py +212 -0
  42. mettagrid/mapgen/scenes/mean_distance.py +48 -0
  43. mettagrid/mapgen/scenes/mirror.py +120 -0
  44. mettagrid/mapgen/scenes/multi_left_and_right.py +117 -0
  45. mettagrid/mapgen/scenes/nop.py +15 -0
  46. mettagrid/mapgen/scenes/radial_maze.py +50 -0
  47. mettagrid/mapgen/scenes/random.py +65 -0
  48. mettagrid/mapgen/scenes/random_dcss_scene.py +48 -0
  49. mettagrid/mapgen/scenes/random_objects.py +35 -0
  50. mettagrid/mapgen/scenes/random_scene.py +31 -0
  51. mettagrid/mapgen/scenes/random_yaml_scene.py +32 -0
  52. mettagrid/mapgen/scenes/remove_agents.py +26 -0
  53. mettagrid/mapgen/scenes/room_grid.py +77 -0
  54. mettagrid/mapgen/scenes/spiral.py +104 -0
  55. mettagrid/mapgen/scenes/transplant_scene.py +42 -0
  56. mettagrid/mapgen/scenes/varied_terrain.py +376 -0
  57. mettagrid/mapgen/scenes/wfc.py +269 -0
  58. mettagrid/mapgen/scenes/yaml.py +22 -0
  59. mettagrid/mapgen/tools/dcss_import.py +123 -0
  60. mettagrid/mapgen/tools/gen.py +94 -0
  61. mettagrid/mapgen/tools/gen_scene.py +51 -0
  62. mettagrid/mapgen/tools/view.py +25 -0
  63. mettagrid/mapgen/types.py +81 -0
  64. mettagrid/mapgen/utils/ascii_grid.py +55 -0
  65. mettagrid/mapgen/utils/draw.py +32 -0
  66. mettagrid/mapgen/utils/make_scene_config.py +36 -0
  67. mettagrid/mapgen/utils/pattern.py +181 -0
  68. mettagrid/mapgen/utils/s3utils.py +44 -0
  69. mettagrid/mapgen/utils/show.py +21 -0
  70. mettagrid/mapgen/utils/storable_map.py +117 -0
  71. mettagrid/mapgen/utils/storable_map_index.py +92 -0
  72. mettagrid/mapgen/utils/thumbnail.py +360 -0
  73. mettagrid/mettagrid_c.pyi +232 -0
  74. mettagrid/mettagrid_c.so +0 -0
  75. mettagrid/profiling/__init__.py +1 -0
  76. mettagrid/profiling/memory_monitor.py +242 -0
  77. mettagrid/profiling/stopwatch.py +711 -0
  78. mettagrid/profiling/system_monitor.py +355 -0
  79. mettagrid/py.typed +0 -0
  80. mettagrid/renderer/miniscope.py +138 -0
  81. mettagrid/renderer/nethack.py +101 -0
  82. mettagrid/test_support/__init__.py +7 -0
  83. mettagrid/test_support/actions.py +426 -0
  84. mettagrid/test_support/mapgen.py +77 -0
  85. mettagrid/test_support/observation_helper.py +50 -0
  86. mettagrid/test_support/orientation.py +100 -0
  87. mettagrid/test_support/token_types.py +18 -0
  88. mettagrid/util/__init__.py +0 -0
  89. mettagrid/util/char_encoder.py +55 -0
  90. mettagrid/util/debug.py +501 -0
  91. mettagrid/util/dict_utils.py +10 -0
  92. mettagrid/util/diversity.py +71 -0
  93. mettagrid/util/episode_stats_db.py +157 -0
  94. mettagrid/util/file.py +503 -0
  95. mettagrid/util/grid_object_formatter.py +95 -0
  96. mettagrid/util/module.py +12 -0
  97. mettagrid/util/replay_writer.py +136 -0
  98. mettagrid/util/stats_writer.py +55 -0
  99. mettagrid/util/uri.py +162 -0
  100. mettagrid-0.2.0.3.dist-info/METADATA +256 -0
  101. mettagrid-0.2.0.3.dist-info/RECORD +105 -0
  102. mettagrid-0.2.0.3.dist-info/WHEEL +5 -0
  103. mettagrid-0.2.0.3.dist-info/entry_points.txt +2 -0
  104. mettagrid-0.2.0.3.dist-info/licenses/LICENSE +21 -0
  105. mettagrid-0.2.0.3.dist-info/top_level.txt +1 -0
mettagrid/__init__.py ADDED
@@ -0,0 +1,62 @@
1
+ """
2
+ MettaGrid - Multi-agent reinforcement learning grid environments.
3
+
4
+ This module provides various environment adapters for different RL frameworks:
5
+ - MettaGridCore: Core C++ wrapper (no training features)
6
+ - MettaGridEnv: Training environment (PufferLib-based with stats/replay)
7
+ - MettaGridGymEnv: Gymnasium adapter
8
+ - MettaGridPettingZooEnv: PettingZoo adapter
9
+
10
+ All adapters inherit from MettaGridCore and provide framework-specific interfaces.
11
+ For PufferLib integration, use PufferLib's MettaPuff wrapper directly.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ from mettagrid.config.mettagrid_config import MettaGridConfig
17
+
18
+ # Import environment classes
19
+ from mettagrid.core import MettaGridCore
20
+
21
+ # Import other commonly used classes
22
+ from mettagrid.envs.gym_env import MettaGridGymEnv
23
+ from mettagrid.envs.mettagrid_env import MettaGridEnv
24
+ from mettagrid.envs.pettingzoo_env import MettaGridPettingZooEnv
25
+ from mettagrid.map_builder.map_builder import GameMap
26
+
27
+ # Import data types from C++ module (source of truth)
28
+ from mettagrid.mettagrid_c import (
29
+ dtype_actions,
30
+ dtype_masks,
31
+ dtype_observations,
32
+ dtype_rewards,
33
+ dtype_success,
34
+ dtype_terminals,
35
+ dtype_truncations,
36
+ )
37
+ from mettagrid.util.replay_writer import ReplayWriter
38
+ from mettagrid.util.stats_writer import StatsWriter
39
+
40
# Public API of the package, grouped by role (order is significant for `from mettagrid import *`).
__all__ = [
    "MettaGridConfig",  # config
    "MettaGridCore",  # core C++ wrapper
    "MettaGridEnv",  # main environment (backward compatible)
    "MettaGridGymEnv", "MettaGridPettingZooEnv",  # framework adapters
    # dtypes re-exported from the C++ module
    "dtype_actions", "dtype_observations", "dtype_rewards", "dtype_terminals",
    "dtype_truncations", "dtype_masks", "dtype_success",
    "GameMap", "ReplayWriter", "StatsWriter",  # supporting classes
]
@@ -0,0 +1,7 @@
1
+ """
2
+ Configuration builders for Metta environments.
3
+ """
4
+
5
+ from . import building, empty_converters, envs
6
+
7
# Submodules re-exported by `from mettagrid.builder import *`.
__all__ = [
    "building",
    "envs",
    "empty_converters",
]
@@ -0,0 +1,136 @@
1
+ from mettagrid.config.mettagrid_config import AssemblerConfig, ConverterConfig, RecipeConfig, WallConfig
2
+
3
# Static terrain: a plain wall and a wall variant flagged as swappable.
wall, block = WallConfig(type_id=1), WallConfig(type_id=14, swappable=True)

# Altar: turns 3 red batteries into 1 heart, cooldown 10.
altar = ConverterConfig(
    type_id=8,
    cooldown=10,
    input_resources={"battery_red": 3},
    output_resources={"heart": 1},
)
12
+
13
+
14
def make_mine(color: str, type_id: int) -> ConverterConfig:
    """Return a converter with no inputs that outputs one ``ore_<color>`` (cooldown 50)."""
    ore = f"ore_{color}"
    return ConverterConfig(type_id=type_id, output_resources={ore: 1}, cooldown=50)
20
+
21
+
22
# One mine per colour; type ids 2, 3, 4.
mine_red, mine_blue, mine_green = (
    make_mine(color, type_id) for color, type_id in (("red", 2), ("blue", 3), ("green", 4))
)
25
+
26
+
27
def make_generator(color: str, type_id: int) -> ConverterConfig:
    """Return a converter turning one ``ore_<color>`` into one ``battery_<color>`` (cooldown 25)."""
    ore, battery = f"ore_{color}", f"battery_{color}"
    return ConverterConfig(
        type_id=type_id,
        cooldown=25,
        input_resources={ore: 1},
        output_resources={battery: 1},
    )
34
+
35
+
36
# One generator per colour; type ids 5, 6, 7.
generator_red, generator_blue, generator_green = (
    make_generator(color, type_id) for color, type_id in (("red", 5), ("blue", 6), ("green", 7))
)

# Equipment converters built from red resources.
lasery = ConverterConfig(
    type_id=15,
    cooldown=10,
    input_resources={"battery_red": 1, "ore_red": 2},
    output_resources={"laser": 1},
)

armory = ConverterConfig(
    type_id=16,
    cooldown=10,
    input_resources={"ore_red": 3},
    output_resources={"armor": 1},
)

# Assembler building definitions: each holds a single (["Any"], RecipeConfig) recipe entry.
assembler_altar = AssemblerConfig(
    type_id=8,
    recipes=[
        (
            ["Any"],
            RecipeConfig(cooldown=10, input_resources={"battery_red": 3}, output_resources={"heart": 1}),
        )
    ],
)
68
+
69
+
70
def make_assembler_mine(color: str, type_id: int) -> AssemblerConfig:
    """Assembler counterpart of ``make_mine``: one recipe yielding one ``ore_<color>`` (cooldown 50)."""
    recipe = RecipeConfig(output_resources={f"ore_{color}": 1}, cooldown=50)
    return AssemblerConfig(type_id=type_id, recipes=[(["Any"], recipe)])
83
+
84
+
85
# Assembler mines reuse the converter mines' type ids 2, 3, 4.
assembler_mine_red, assembler_mine_blue, assembler_mine_green = (
    make_assembler_mine(color, type_id) for color, type_id in (("red", 2), ("blue", 3), ("green", 4))
)
88
+
89
+
90
def make_assembler_generator(color: str, type_id: int) -> AssemblerConfig:
    """Assembler counterpart of ``make_generator``: ``ore_<color>`` -> ``battery_<color>`` (cooldown 25)."""
    recipe = RecipeConfig(
        input_resources={f"ore_{color}": 1},
        output_resources={f"battery_{color}": 1},
        cooldown=25,
    )
    return AssemblerConfig(type_id=type_id, recipes=[(["Any"], recipe)])
104
+
105
+
106
# Assembler generators reuse the converter generators' type ids 5, 6, 7.
assembler_generator_red, assembler_generator_blue, assembler_generator_green = (
    make_assembler_generator(color, type_id) for color, type_id in (("red", 5), ("blue", 6), ("green", 7))
)

# Assembler versions of the lasery and armory, with the same recipes as the converter versions.
assembler_lasery = AssemblerConfig(
    type_id=15,
    recipes=[
        (
            ["Any"],
            RecipeConfig(cooldown=10, input_resources={"battery_red": 1, "ore_red": 2}, output_resources={"laser": 1}),
        )
    ],
)

assembler_armory = AssemblerConfig(
    type_id=16,
    recipes=[
        (
            ["Any"],
            RecipeConfig(cooldown=10, input_resources={"ore_red": 3}, output_resources={"armor": 1}),
        )
    ],
)
@@ -0,0 +1,85 @@
1
+ from mettagrid.config.mettagrid_config import ConverterConfig, WallConfig
2
+
3
def _empty_converter(type_id: int) -> ConverterConfig:
    """Return a fresh converter with no inputs, no outputs and cooldown 5."""
    return ConverterConfig(type_id=type_id, input_resources={}, output_resources={}, cooldown=5)


wall = WallConfig(type_id=1)
block = WallConfig(type_id=14, swappable=True)

# Placeholder converters: same type ids as the real buildings, but inert (no recipes).
mine_red = _empty_converter(2)
mine_blue = _empty_converter(3)
mine_green = _empty_converter(4)

generator_red = _empty_converter(5)
generator_blue = _empty_converter(6)
generator_green = _empty_converter(7)

altar = _empty_converter(8)

lasery = _empty_converter(15)
armory = _empty_converter(16)
lab = _empty_converter(17)
factory = _empty_converter(18)
temple = _empty_converter(19)
@@ -0,0 +1,221 @@
1
+ from typing import Optional
2
+
3
+ import mettagrid.mapgen.scenes.random
4
+ from mettagrid.config.mettagrid_config import (
5
+ ActionConfig,
6
+ ActionsConfig,
7
+ AgentConfig,
8
+ AgentRewards,
9
+ AttackActionConfig,
10
+ GameConfig,
11
+ MettaGridConfig,
12
+ )
13
+ from mettagrid.map_builder.map_builder import MapBuilderConfig
14
+ from mettagrid.map_builder.perimeter_incontext import PerimeterInContextMapBuilder
15
+ from mettagrid.map_builder.random import RandomMapBuilder
16
+ from mettagrid.mapgen.mapgen import MapGen
17
+
18
+ from . import building, empty_converters
19
+
20
+
21
def make_arena(
    num_agents: int,
    combat: bool = True,
    map_builder: MapBuilderConfig | None = None,  # custom map builder; must match num_agents
) -> MettaGridConfig:
    """Return the "arena" environment config.

    The arena contains walls, an altar, a red mine, a red generator, a lasery and an armory.

    Args:
        num_agents: Total agent count. It is passed to ``GameConfig`` and, when ``map_builder``
            is omitted, to the default ``MapGen`` config.
        combat: When False, the attack action stays configured, but its laser cost is raised
            from 1 to 100. That exceeds the agents' default resource limit of 50 set below,
            presumably making attacks unaffordable.
        map_builder: Optional custom map builder; it must match ``num_agents``. When omitted,
            a 25x25 ``MapGen`` map (border 6) is used whose root ``Random`` scene places 6
            agents. NOTE(review): presumably 6 per instance, so ``num_agents`` should be a
            multiple of 6 -- confirm against MapGen.

    Returns:
        A ``MettaGridConfig`` labelled ``"arena"`` or ``"arena.combat"``.
    """
    arena_objects = {
        "wall": building.wall,
        "altar": building.altar,
        "mine_red": building.mine_red,
        "generator_red": building.generator_red,
        "lasery": building.lasery,
        "armory": building.armory,
    }

    actions = ActionsConfig(
        noop=ActionConfig(),
        move=ActionConfig(),
        rotate=ActionConfig(enabled=False),  # Disabled for unified movement system
        put_items=ActionConfig(),
        get_items=ActionConfig(),
        attack=AttackActionConfig(consumed_resources={"laser": 1}, defense_resources={"armor": 1}),
        swap=ActionConfig(enabled=False),
        change_color=ActionConfig(enabled=False),
    )
    if not combat:
        # Keep the action but price it out of reach rather than removing it.
        actions.attack.consumed_resources = {"laser": 100}

    if map_builder is None:
        random_scene = mettagrid.mapgen.scenes.random.Random
        scene_params = random_scene.Params(
            agents=6,
            objects={
                "wall": 10,
                "altar": 5,
                "mine_red": 10,
                "generator_red": 5,
                "lasery": 1,
                "armory": 1,
            },
        )
        map_builder = MapGen.Config(
            num_agents=num_agents,
            width=25,
            height=25,
            border_width=6,
            instance_border_width=0,
            root=random_scene.factory(params=scene_params),
        )

    agent = AgentConfig(
        default_resource_limit=50,
        resource_limits={"heart": 255},
        rewards=AgentRewards(inventory={"heart": 1}),
    )
    game = GameConfig(
        num_agents=num_agents,
        actions=actions,
        objects=arena_objects,
        agent=agent,
        map_builder=map_builder,
    )
    return MettaGridConfig(label="arena" + (".combat" if combat else ""), game=game)
98
+
99
+
100
def make_navigation(num_agents: int) -> MettaGridConfig:
    """Return a navigation config with altars, walls and a heart-only resource set.

    The altar is a copy of ``empty_converters.altar`` reconfigured as follows:
    cooldown 255, one initial resource, ``max_conversions = 0``, no inputs, and one
    heart as output. Only ``move`` and ``get_items`` are configured (``rotate``
    disabled), and agents are rewarded 1 per heart in inventory.
    """
    altar = empty_converters.altar.model_copy()
    altar.cooldown = 255  # Maximum cooldown
    altar.initial_resource_count = 1
    altar.max_conversions = 0
    altar.input_resources = {}
    altar.output_resources = {"heart": 1}

    nav_actions = ActionsConfig(
        move=ActionConfig(),
        rotate=ActionConfig(enabled=False),
        get_items=ActionConfig(),
    )
    nav_agent = AgentConfig(rewards=AgentRewards(inventory={"heart": 1}))

    return MettaGridConfig(
        game=GameConfig(
            num_agents=num_agents,
            objects={"altar": altar, "wall": building.wall},
            resource_names=["heart"],
            actions=nav_actions,
            agent=nav_agent,
            # Always provide a concrete map builder config so tests can set width/height
            map_builder=RandomMapBuilder.Config(agents=num_agents),
        )
    )
132
+
133
+
134
def make_navigation_sequence(num_agents: int) -> MettaGridConfig:
    """Return a config for the red chain mine -> generator -> altar.

    Copies of ``building.altar``, ``building.mine_red`` and ``building.generator_red``
    all get cooldown 15. The altar input is reduced to a single red battery. Small
    shaping rewards are given for ore (0.001) and batteries (0.01), and 1 per heart.
    """
    altar, mine, generator = (
        source.model_copy() for source in (building.altar, building.mine_red, building.generator_red)
    )
    altar.input_resources = {"battery_red": 1}
    for converter in (altar, mine, generator):
        converter.cooldown = 15

    shaped_rewards = AgentRewards(
        inventory={
            "heart": 1,
            "ore_red": 0.001,
            "battery_red": 0.01,
        },
    )
    seq_agent = AgentConfig(
        rewards=shaped_rewards,
        default_resource_limit=1,
        resource_limits={"heart": 100},
    )

    return MettaGridConfig(
        game=GameConfig(
            num_agents=num_agents,
            objects={
                "altar": altar,
                "wall": building.wall,
                "mine_red": mine,
                "generator_red": generator,
            },
            resource_names=["heart", "ore_red", "battery_red"],
            actions=ActionsConfig(
                move=ActionConfig(),
                rotate=ActionConfig(enabled=False),
                get_items=ActionConfig(),
            ),
            agent=seq_agent,
            # Always provide a concrete map builder config so tests can set width/height
            map_builder=RandomMapBuilder.Config(agents=num_agents),
        )
    )
175
+
176
+
177
def make_icl_resource_chain(
    num_agents: int,
    max_steps,
    game_objects: dict,
    map_builder_objects: dict,
    width: int = 6,
    height: int = 6,
    obstacle_type: Optional[str] = None,
    density: Optional[str] = None,
) -> MettaGridConfig:
    """Build an in-context-learning resource-chain config with one map instance per agent.

    Args:
        num_agents: Number of agents. This is also the number of ``MapGen`` instances; each
            instance is built by ``PerimeterInContextMapBuilder`` with a single agent.
        max_steps: Forwarded to ``GameConfig.max_steps``.
        game_objects: Object configs for the game. A ``"wall"`` entry
            (``empty_converters.wall``) is added to a *copy*; the caller's dict is left untouched.
        map_builder_objects: Object spec forwarded to the per-instance map builder.
        width: Per-instance map width.
        height: Per-instance map height.
        obstacle_type: Forwarded to ``PerimeterInContextMapBuilder.Config``.
        density: Forwarded to ``PerimeterInContextMapBuilder.Config``.

    Returns:
        The assembled ``MettaGridConfig``.
    """
    # Previously this wrote "wall" straight into the caller's dict, leaking a side effect
    # into shared/reused object tables. Build a new mapping instead (same keys, same order).
    objects = {**game_objects, "wall": empty_converters.wall}
    return MettaGridConfig(
        game=GameConfig(
            max_steps=max_steps,
            num_agents=num_agents,
            objects=objects,
            map_builder=MapGen.Config(
                instances=num_agents,
                instance_map=PerimeterInContextMapBuilder.Config(
                    agents=1,
                    width=width,
                    height=height,
                    objects=map_builder_objects,
                    obstacle_type=obstacle_type,
                    density=density,
                ),
            ),
            actions=ActionsConfig(
                move=ActionConfig(),
                rotate=ActionConfig(enabled=False),  # Disabled for unified movement system
                get_items=ActionConfig(),
                put_items=ActionConfig(),
            ),
            agent=AgentConfig(
                rewards=AgentRewards(
                    inventory={
                        "heart": 1,
                    },
                ),
                default_resource_limit=1,
                resource_limits={"heart": 15},
            ),
        )
    )
@@ -0,0 +1,35 @@
1
+ """Configuration module for mettagrid."""
2
+
3
+ from .config import Config
4
+ from .mettagrid_c_config import from_mettagrid_config
5
+ from .mettagrid_config import (
6
+ ActionConfig,
7
+ ActionsConfig,
8
+ AgentConfig,
9
+ AgentRewards,
10
+ AttackActionConfig,
11
+ ChangeGlyphActionConfig,
12
+ ConverterConfig,
13
+ GameConfig,
14
+ GlobalObsConfig,
15
+ MettaGridConfig,
16
+ StatsRewards,
17
+ WallConfig,
18
+ )
19
+
20
# Public API of mettagrid.config (order is significant for `from mettagrid.config import *`).
__all__ = [
    "Config", "from_mettagrid_config",  # base model + C++ config conversion
    "MettaGridConfig",
    # action configs
    "ActionConfig", "ActionsConfig",
    # agent configs
    "AgentConfig", "AgentRewards",
    "AttackActionConfig", "ChangeGlyphActionConfig",
    # object / game configs
    "ConverterConfig", "GameConfig", "GlobalObsConfig", "StatsRewards", "WallConfig",
]
@@ -0,0 +1,106 @@
1
from __future__ import annotations

import types
from typing import Any, NoReturn, Self, Union, get_args, get_origin

from pydantic import BaseModel, ConfigDict, TypeAdapter
6
+
7
+
8
class Config(BaseModel):
    """
    Common extension of Pydantic's BaseModel that:
    - sets `extra="forbid"` by default
    - adds `override` and `update` methods for overriding values based on `path.to.value` keys
    """

    model_config = ConfigDict(extra="forbid")

    def _auto_initialize_field(self, parent_obj: "Config", field_name: str) -> "Config | None":
        """Auto-initialize a None Config field if possible.

        Looks up ``field_name`` among ``parent_obj``'s model fields. If the field's annotation
        (after unwrapping ``Optional[T]`` / ``T | None``) is a ``Config`` subclass, the method
        builds it with no arguments, assigns it onto ``parent_obj`` and returns it.

        Returns None when the field is unknown, is not a Config type, or cannot be constructed
        without arguments. Pydantic's ``ValidationError`` subclasses ``ValueError``, so missing
        required fields are caught here.
        """
        field = type(parent_obj).model_fields.get(field_name)
        if not field:
            return None

        field_type = self._unwrap_optional(field.annotation)
        if not (isinstance(field_type, type) and issubclass(field_type, Config)):
            return None

        try:
            new_instance = field_type()
            setattr(parent_obj, field_name, new_instance)
            return new_instance
        except (TypeError, ValueError):
            return None

    def _unwrap_optional(self, field_type):
        """Unwrap ``Optional[T]`` / ``T | None`` -> ``T`` if applicable, else return the original type.

        ``typing.Optional[T]`` has origin ``typing.Union``, while PEP 604 ``T | None`` has origin
        ``types.UnionType`` (on Python 3.10-3.13), so both origins must be checked. Unions with
        more than one non-None member are returned unchanged.
        """
        origin = get_origin(field_type)
        if origin is Union or origin is types.UnionType:
            non_none_types = [arg for arg in get_args(field_type) if arg is not type(None)]
            return non_none_types[0] if len(non_none_types) == 1 else field_type
        return field_type

    def override(self, key: str, value: Any) -> Self:
        """Override a single value addressed by a dotted ``key`` path (e.g. ``"game.agent.rewards"``).

        Intermediate segments may traverse Config attributes or dict keys. ``None`` Config fields
        along the path are auto-initialized when their type can be constructed without arguments.
        The final segment either sets a dict key (new keys allowed) or sets an existing Config
        field, with ``value`` validated against that field's annotation.

        Returns:
            ``self``, to allow chaining.

        Raises:
            ValueError: if a path segment is missing or not traversable, or if validation of
                ``value`` fails (pydantic's ``ValidationError`` is a ``ValueError``).
        """
        key_path = key.split(".")

        def fail(error: str) -> NoReturn:
            raise ValueError(
                f"Override failed. Full config:\n {self.model_dump_json(indent=2)}\nOverride {key} failed: {error}"
            )

        inner_cfg: Config | dict[str, Any] = self
        traversed_path: list[str] = []
        for key_part in key_path[:-1]:
            if isinstance(inner_cfg, dict):
                if key_part not in inner_cfg:
                    fail(f"key {key} not found")
                inner_cfg = inner_cfg[key_part]
                traversed_path.append(key_part)
                continue

            if not hasattr(inner_cfg, key_part):
                failed_path = ".".join(traversed_path + [key_part])
                fail(f"key {failed_path} not found")

            next_inner_cfg = getattr(inner_cfg, key_part)
            if next_inner_cfg is None:
                # Auto-initialize None Config fields
                next_inner_cfg = self._auto_initialize_field(inner_cfg, key_part)
                if next_inner_cfg is None:
                    failed_path = ".".join(traversed_path + [key_part])
                    fail(f"Cannot auto-initialize None field {failed_path}")

            if not isinstance(next_inner_cfg, (Config, dict)):
                failed_path = ".".join(traversed_path + [key_part])
                fail(f"key {failed_path} is not a Config object")

            inner_cfg = next_inner_cfg
            traversed_path.append(key_part)

        # We allow dicts to get new keys, but not Configs. This is because we want to allow overrides like
        # env_cfg.game.agent.rewards.inventory.ore_red = 0.1
        # without requiring that "ore_red" was already in the inventory dict. Note that allowing overrides / updates
        # to dicts like this leads to an obnoxious inconsistency in the way dicts are updated via overrides
        # (foo.bar.baz = 1) vs how they're set in Python (foo.bar["baz"] = 1).
        if isinstance(inner_cfg, Config):
            if not hasattr(inner_cfg, key_path[-1]):
                fail(f"key {key} not found")

        if isinstance(inner_cfg, dict):
            inner_cfg[key_path[-1]] = value
            return self

        # hasattr() also accepts methods/properties; require a real model field.
        cls = type(inner_cfg)
        field = cls.model_fields.get(key_path[-1])
        if field is None:
            fail(f"key {key} is not a valid field")

        # Plain setattr does not validate (no validate_assignment), so validate explicitly first.
        value = TypeAdapter(field.annotation).validate_python(value)
        setattr(inner_cfg, key_path[-1], value)

        return self

    def update(self, updates: dict[str, Any]) -> Self:
        """Apply ``override`` for each ``key -> value`` pair in insertion order; returns ``self``."""
        for key, value in updates.items():
            self.override(key, value)
        return self