kaggle-environments 1.15.3__py2.py3-none-any.whl → 1.16.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kaggle-environments might be problematic. Click here for more details.
- kaggle_environments/__init__.py +1 -1
- kaggle_environments/envs/chess/chess.json +4 -4
- kaggle_environments/envs/chess/chess.py +209 -51
- kaggle_environments/envs/chess/test_chess.py +43 -1
- kaggle_environments/envs/connectx/connectx.ipynb +3183 -0
- kaggle_environments/envs/football/football.ipynb +75 -0
- kaggle_environments/envs/halite/halite.ipynb +44736 -0
- kaggle_environments/envs/kore_fleets/kore_fleets.ipynb +112 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/Bot.java +54 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/README.md +26 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/jars/hamcrest-core-1.3.jar +0 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/jars/junit-4.13.2.jar +0 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Board.java +518 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Cell.java +61 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Configuration.java +24 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Direction.java +166 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Fleet.java +72 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/KoreJson.java +97 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Observation.java +72 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Pair.java +13 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Player.java +68 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Point.java +65 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Shipyard.java +70 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/ShipyardAction.java +59 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/test/BoardTest.java +567 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/test/ConfigurationTest.java +25 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/test/KoreJsonTest.java +62 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/test/ObservationTest.java +46 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/test/PointTest.java +21 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/test/ShipyardTest.java +22 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/README.md +55 -0
- kaggle_environments/envs/lux_ai_2021/README.md +3 -0
- kaggle_environments/envs/lux_ai_2021/dimensions/754.js.LICENSE.txt +296 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/simple.tar.gz +0 -0
- kaggle_environments/envs/lux_ai_2021/testing.md +23 -0
- kaggle_environments/envs/lux_ai_2021/todo.md.og +18 -0
- kaggle_environments/envs/lux_ai_s2/.gitignore +1 -0
- kaggle_environments/envs/lux_ai_s2/README.md +21 -0
- kaggle_environments/envs/lux_ai_s2/luxai_s2/.DS_Store +0 -0
- kaggle_environments/envs/lux_ai_s2/luxai_s2/map_generator/.DS_Store +0 -0
- kaggle_environments/envs/lux_ai_s3/README.md +21 -0
- kaggle_environments/envs/lux_ai_s3/agents.py +4 -0
- kaggle_environments/envs/lux_ai_s3/index.html +42 -0
- kaggle_environments/envs/lux_ai_s3/lux_ai_s3.json +47 -0
- kaggle_environments/envs/lux_ai_s3/lux_ai_s3.py +138 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/__init__.py +1 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/env.py +924 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/globals.py +13 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/params.py +101 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/profiler.py +140 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/pygame_render.py +270 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/spaces.py +30 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/state.py +399 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/utils.py +12 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/wrappers.py +187 -0
- kaggle_environments/envs/lux_ai_s3/test_agents/python/agent.py +71 -0
- kaggle_environments/envs/lux_ai_s3/test_agents/python/lux/__init__.py +0 -0
- kaggle_environments/envs/lux_ai_s3/test_agents/python/lux/kit.py +27 -0
- kaggle_environments/envs/lux_ai_s3/test_agents/python/lux/utils.py +17 -0
- kaggle_environments/envs/lux_ai_s3/test_agents/python/main.py +53 -0
- kaggle_environments/envs/lux_ai_s3/test_lux.py +9 -0
- kaggle_environments/envs/tictactoe/tictactoe.ipynb +1393 -0
- {kaggle_environments-1.15.3.dist-info → kaggle_environments-1.16.0.dist-info}/METADATA +2 -2
- {kaggle_environments-1.15.3.dist-info → kaggle_environments-1.16.0.dist-info}/RECORD +68 -10
- {kaggle_environments-1.15.3.dist-info → kaggle_environments-1.16.0.dist-info}/WHEEL +1 -1
- {kaggle_environments-1.15.3.dist-info → kaggle_environments-1.16.0.dist-info}/LICENSE +0 -0
- {kaggle_environments-1.15.3.dist-info → kaggle_environments-1.16.0.dist-info}/entry_points.txt +0 -0
- {kaggle_environments-1.15.3.dist-info → kaggle_environments-1.16.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,399 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import chex
|
|
3
|
+
import flax
|
|
4
|
+
import jax
|
|
5
|
+
import jax.numpy as jnp
|
|
6
|
+
import numpy as np
|
|
7
|
+
from flax import struct
|
|
8
|
+
|
|
9
|
+
from luxai_s3.params import MAP_TYPES, EnvParams
|
|
10
|
+
from luxai_s3.utils import to_numpy
|
|
11
|
+
# Integer tile-type codes. gen_map writes NEBULA_TILE and ASTEROID_TILE into
# MapTile.tile_type; 0 is used there for every other tile.
EMPTY_TILE = 0
NEBULA_TILE = 1
ASTEROID_TILE = 2

# Energy-field function table. Each entry takes four arguments (d, x, y, z).
# NOTE(review): presumably d is the tile-to-node distance and (x, y, z) are the
# per-node parameters stored in EnvState.energy_node_fns, with the first number of
# each row selecting the entry in this list -- confirm against the caller in env.py.
ENERGY_NODE_FNS = [
    lambda d, x, y, z: jnp.sin(d * x + y) * z, lambda d, x, y, z: (x / (d + 1) + y) * z
]
|
|
18
|
+
|
|
19
|
+
@struct.dataclass
class UnitState:
    """Per-unit state: a position and an energy value.

    gen_state instantiates this with batched arrays: ``position`` as an int16
    array of shape (num_teams, max_units, 2) and ``energy`` as an int16 array of
    shape (num_teams, max_units, 1).
    """
    position: chex.Array
    """Position of the unit with shape (2) for x, y"""
    # NOTE(review): annotated as int, but gen_state passes an int16 array here.
    energy: int
    """Energy of the unit"""
|
|
25
|
+
|
|
26
|
+
@struct.dataclass
class MapTile:
    """Per-tile map features: an energy value and a tile-type code.

    gen_map instantiates this with two int16 2-D arrays (one entry per map tile)
    rather than scalars.
    """
    # NOTE(review): annotated as int, but gen_map passes a 2-D int16 array.
    energy: int
    """Energy of the tile, generated via energy_nodes and energy_node_fns"""
    # Holds the tile-type codes (gen_map writes NEBULA_TILE / ASTEROID_TILE, 0 otherwise).
    tile_type: int
    """Type of the tile"""
|
|
32
|
+
|
|
33
|
+
@struct.dataclass
class EnvState:
    """Complete environment state, as assembled by ``gen_state``.

    T is the number of teams, N a per-field maximum count, and H / W the map
    height / width. Some of the per-field descriptions below do not match what
    ``gen_state`` currently builds; the ``#`` notes call out the differences.
    """
    # NOTE(review): gen_state stores a UnitState whose ``position`` is (T, N, 2) and
    # ``energy`` is (T, N, 1), not a single (T, N, 3) array as the text below says.
    units: UnitState
    """Units in the environment with shape (T, N, 3) for T teams, N max units, and 3 features.

    3 features are for position (x, y), and energy
    """
    units_mask: chex.Array
    """Mask of units in the environment with shape (T, N) for T teams, N max units"""
    energy_nodes: chex.Array
    """Energy nodes in the environment with shape (N, 2) for N max energy nodes, and 2 features.

    2 features are for position (x, y)
    """

    energy_node_fns: chex.Array
    """Energy node functions for computing the energy field of the map. They describe the function with a sequence of numbers

    The first number is the function used. The subsequent numbers parameterize the function. The function is applied to distance of map tile to energy node and the function parameters.
    """

    # energy_field: chex.Array
    # """Energy field in the environment with shape (H, W) for H height, W width. This is generated from other state"""

    energy_nodes_mask: chex.Array
    """Mask of energy nodes in the environment with shape (N) for N max energy nodes"""
    relic_nodes: chex.Array
    """Relic nodes in the environment with shape (N, 2) for N max relic nodes, and 2 features.

    2 features are for position (x, y)
    """
    relic_node_configs: chex.Array
    """Relic node configs in the environment with shape (N, K, K) for N max relic nodes and a KxK relic configuration"""
    relic_nodes_mask: chex.Array
    """Mask of relic nodes in the environment with shape (N, ) for N max relic nodes"""
    # NOTE(review): gen_state builds this as an int16 array of shape (map_width, map_height)
    # holding summed relic_node_configs entries, i.e. integer weights rather than booleans.
    relic_nodes_map_weights: chex.Array
    """Map of relic nodes in the environment with shape (H, W) for H height, W width. True if a relic node is present, False otherwise. This is generated from other state"""

    # NOTE(review): this is a MapTile holding two separate 2-D arrays (energy, tile_type),
    # not one stacked (W, H, 2) array.
    map_features: MapTile
    """Map features in the environment with shape (W, H, 2) for W width, H height
    """

    sensor_mask: chex.Array
    """Sensor mask in the environment with shape (T, H, W) for T teams, H height, W width. This is generated from other state"""

    vision_power_map: chex.Array
    """Vision power map in the environment with shape (T, H, W) for T teams, H height, W width. This is generated from other state"""

    team_points: chex.Array
    """Team points in the environment with shape (T) for T teams"""
    team_wins: chex.Array
    """Team wins in the environment with shape (T) for T teams"""

    steps: int = 0
    """steps taken in the environment"""
    match_steps: int = 0
    """steps taken in the current match"""
|
|
90
|
+
|
|
91
|
+
@struct.dataclass
class EnvObs:
    """Partial observation of the environment.

    Carries a subset of the EnvState fields (units, masks, map features, relic
    nodes, team scores and step counters).
    """
    # NOTE(review): typed as UnitState (separate position / energy arrays); the
    # "(T, N, 3)" description below presumably predates that layout -- confirm.
    units: UnitState
    """Units in the environment with shape (T, N, 3) for T teams, N max units, and 3 features.

    3 features are for position (x, y), and energy
    """
    units_mask: chex.Array
    """Mask of units in the environment with shape (T, N) for T teams, N max units"""

    # NOTE(review): undocumented here; presumably the same (T, H, W) layout as
    # EnvState.sensor_mask -- TODO confirm.
    sensor_mask: chex.Array

    map_features: MapTile
    """Map features in the environment with shape (W, H, 2) for W width, H height
    """
    relic_nodes: chex.Array
    """Position of all relic nodes with shape (N, 2) for N max relic nodes and 2 features for position (x, y). Number is -1 if not visible"""
    relic_nodes_mask: chex.Array
    """Mask of all relic nodes with shape (N) for N max relic nodes"""
    team_points: chex.Array
    """Team points in the environment with shape (T) for T teams"""
    team_wins: chex.Array
    """Team wins in the environment with shape (T) for T teams"""
    steps: int = 0
    """steps taken in the environment"""
    match_steps: int = 0
    """steps taken in the current match"""
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def serialize_env_states(env_states: list[EnvState]):
    """Convert a sequence of environment states into JSON-friendly nested dicts.

    Each state is turned into a flax state dict and walked recursively:
    mask / derived fields are dropped, node arrays are filtered by their mask,
    JAX arrays become Python lists, and every other leaf is passed through.

    Args:
        env_states: the states to serialize, in order.

    Returns:
        A list with one serialized dict per input state.
    """
    # Top-level keys that are omitted from the serialized output.
    dropped_keys = {
        "sensor_mask",
        "relic_nodes_mask",
        "energy_nodes_mask",
        "energy_node_fns",
        "relic_nodes_map_weights",
    }

    def serialize_array(root: EnvState, arr, key_path: str = ""):
        if key_path in dropped_keys:
            return None
        # Node arrays keep only the rows selected by their mask on the root state.
        if key_path == "relic_nodes":
            return root.relic_nodes[root.relic_nodes_mask].tolist()
        if key_path == "relic_node_configs":
            return root.relic_node_configs[root.relic_nodes_mask].tolist()
        if key_path == "energy_nodes":
            return root.energy_nodes[root.energy_nodes_mask].tolist()
        if isinstance(arr, jnp.ndarray):
            return arr.tolist()
        if not isinstance(arr, dict):
            return arr
        converted = {}
        for name, value in arr.items():
            child_path = key_path + "/" + name if key_path else name
            child = serialize_array(root, value, child_path)
            # None marks a dropped field; leave it out of the result entirely.
            if child is not None:
                converted[name] = child
        return converted

    return [
        serialize_array(env_state, flax.serialization.to_state_dict(env_state))
        for env_state in env_states
    ]
|
|
149
|
+
|
|
150
|
+
def serialize_env_actions(env_actions: list):
    """Convert a sequence of actions into JSON-friendly structures.

    Each action is turned into a flax state dict and walked recursively:
    NumPy and JAX arrays become Python lists, dicts are rebuilt without
    ``None`` values, and every other leaf is passed through unchanged.

    Args:
        env_actions: the actions to serialize, in order.

    Returns:
        A list with one serialized entry per input action.
    """
    def serialize_array(arr, key_path: str = ""):
        if isinstance(arr, (np.ndarray, jnp.ndarray)):
            return arr.tolist()
        if not isinstance(arr, dict):
            return arr
        converted = {}
        for name, value in arr.items():
            child_path = key_path + "/" + name if key_path else name
            child = serialize_array(value, child_path)
            if child is not None:
                converted[name] = child
        return converted

    return [
        serialize_array(flax.serialization.to_state_dict(action))
        for action in env_actions
    ]
|
|
172
|
+
|
|
173
|
+
def state_to_flat_obs(state: EnvState) -> chex.Array:
    """Placeholder for flattening a state into an observation array.

    Not implemented: the argument is ignored and ``None`` is returned.
    """
    return None
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def flat_obs_to_state(flat_obs: chex.Array) -> EnvState:
    """Placeholder for rebuilding a state from a flat observation array.

    Not implemented: the argument is ignored and ``None`` is returned.
    """
    return None
|
|
179
|
+
|
|
180
|
+
@functools.partial(jax.jit, static_argnums=(2, 3, 4, 5, 6, 7, 8, 9))
def gen_state(key: chex.PRNGKey, env_params: EnvParams, max_units: int, num_teams: int, map_type: int, map_width: int, map_height: int, max_energy_nodes: int, max_relic_nodes: int, relic_config_size: int) -> EnvState:
    """Build the initial EnvState for a new game.

    Generates the map via ``gen_map``, accumulates the relic node configs into a
    per-tile weight map, and fills every other field with zeros.

    Args:
        key: PRNG key forwarded to ``gen_map``.
        env_params: environment parameters forwarded to ``gen_map``.
        max_units: number of unit slots per team.
        num_teams: number of teams.
        map_type: index into MAP_TYPES, forwarded to ``gen_map``.
        map_width: map width in tiles.
        map_height: map height in tiles.
        max_energy_nodes: number of energy node slots.
        max_relic_nodes: number of relic node slots.
        relic_config_size: side length K of each KxK relic configuration.

    Returns:
        A freshly initialised EnvState.

    All arguments except ``key`` and ``env_params`` are static under jit.
    """
    # NOTE(review): gen_map's signature is (..., map_height, map_width, ...) but width is
    # passed first here, so inside gen_map the two are swapped. Harmless for square maps;
    # confirm whether this is intentional before using non-square maps.
    generated = gen_map(key, env_params, map_type, map_width, map_height, max_energy_nodes, max_relic_nodes, relic_config_size)
    # Weight map is indexed [x, y], hence the (width, height) shape.
    relic_nodes_map_weights = jnp.zeros(
        shape=(map_width, map_height), dtype=jnp.int16
    )

    # TODO (this could be optimized better)
    def update_relic_node(relic_nodes_map_weights, relic_data):
        # Scan body: add one relic node's KxK config onto the weight map.
        relic_node, relic_node_config, mask = relic_data
        # Top-left corner of the KxK window centred on the relic node.
        start_y = relic_node[1] - relic_config_size // 2
        start_x = relic_node[0] - relic_config_size // 2
        # Python loops are unrolled at trace time (relic_config_size is static).
        for dy in range(relic_config_size):
            for dx in range(relic_config_size):
                y, x = start_y + dy, start_x + dx
                valid_pos = jnp.logical_and(
                    jnp.logical_and(y >= 0, x >= 0),
                    jnp.logical_and(y < map_height, x < map_width),
                )
                # Keep the updated map only for in-bounds tiles of active relic nodes.
                relic_nodes_map_weights = jnp.where(
                    valid_pos & mask,
                    relic_nodes_map_weights.at[x, y].add(relic_node_config[dx, dy].astype(jnp.int16)),
                    relic_nodes_map_weights,
                )
        return relic_nodes_map_weights, None

    # this is really slow...
    relic_nodes_map_weights, _ = jax.lax.scan(
        update_relic_node,
        relic_nodes_map_weights,
        (
            generated["relic_nodes"],
            generated["relic_node_configs"],
            generated["relic_nodes_mask"],
        ),
    )
    state = EnvState(
        units=UnitState(position=jnp.zeros(shape=(num_teams, max_units, 2), dtype=jnp.int16), energy=jnp.zeros(shape=(num_teams, max_units, 1), dtype=jnp.int16)),
        units_mask=jnp.zeros(
            shape=(num_teams, max_units), dtype=jnp.bool
        ),
        team_points=jnp.zeros(shape=(num_teams), dtype=jnp.int32),
        team_wins=jnp.zeros(shape=(num_teams), dtype=jnp.int32),
        energy_nodes=generated["energy_nodes"],
        energy_node_fns=generated["energy_node_fns"],
        energy_nodes_mask=generated["energy_nodes_mask"],
        # energy_field=jnp.zeros(shape=(params.map_height, params.map_width), dtype=jnp.int16),
        relic_nodes=generated["relic_nodes"],
        relic_nodes_mask=generated["relic_nodes_mask"],
        relic_node_configs=generated["relic_node_configs"],
        relic_nodes_map_weights=relic_nodes_map_weights,
        sensor_mask=jnp.zeros(
            shape=(num_teams, map_height, map_width),
            dtype=jnp.bool,
        ),
        vision_power_map=jnp.zeros(shape=(num_teams, map_height, map_width), dtype=jnp.int16),
        map_features=generated["map_features"],
    )
    return state
|
|
239
|
+
|
|
240
|
+
@functools.partial(jax.jit, static_argnums=(2, 3, 4, 5, 6, 7))
def gen_map(key: chex.PRNGKey, params: EnvParams, map_type: int, map_height: int, map_width: int, max_energy_nodes: int, max_relic_nodes: int, relic_config_size: int) -> dict:
    """Generate the map layout (tiles, relic nodes, energy nodes) for a new game.

    For the ``"random"`` map type, nebula and asteroid tiles, relic nodes and
    energy nodes are placed from Perlin noise and mirrored so the two halves of
    the map correspond. For any other map type an empty map is returned: all
    tiles are 0, all node masks are False, and the configs / functions are zero.

    Args:
        key: PRNG key driving all random choices.
        params: environment parameters (not read by this function).
        map_type: index into MAP_TYPES.
        map_height: first dimension of the generated 2-D arrays.
        map_width: second dimension of the generated 2-D arrays.
        max_energy_nodes: number of energy node slots.
        max_relic_nodes: number of relic node slots.
        relic_config_size: side length K of each KxK relic configuration.

    Returns:
        A dict with keys ``map_features``, ``energy_nodes``, ``energy_node_fns``,
        ``relic_nodes``, ``energy_nodes_mask``, ``relic_nodes_mask`` and
        ``relic_node_configs``.

    All arguments except ``key`` and ``params`` are static under jit.
    """
    map_features = MapTile(energy=jnp.zeros(
        shape=(map_height, map_width), dtype=jnp.int16
    ), tile_type=jnp.zeros(
        shape=(map_height, map_width), dtype=jnp.int16
    ))
    energy_nodes = jnp.zeros(shape=(max_energy_nodes, 2), dtype=jnp.int16)
    energy_nodes_mask = jnp.zeros(shape=(max_energy_nodes), dtype=jnp.bool)
    relic_nodes = jnp.zeros(shape=(max_relic_nodes, 2), dtype=jnp.int16)
    relic_nodes_mask = jnp.zeros(shape=(max_relic_nodes), dtype=jnp.bool)
    # Defaults so that map types other than "random" still return a complete dict
    # instead of failing on unbound names at the return statement.
    relic_node_configs = jnp.zeros(
        shape=(max_relic_nodes, relic_config_size, relic_config_size), dtype=jnp.bool
    )
    energy_node_fns = jnp.zeros(shape=(max_energy_nodes, 4), dtype=jnp.float32)

    if MAP_TYPES[map_type] == "random":

        ### Generate nebula tiles ###
        key, subkey = jax.random.split(key)
        perlin_noise = generate_perlin_noise_2d(subkey, (map_height, map_width), (4, 4))
        noise = jnp.where(perlin_noise > 0.5, 1, 0)
        # mirror along diagonal
        noise = noise | noise.T
        noise = noise[::-1, ::1]
        map_features = map_features.replace(tile_type=jnp.where(noise, NEBULA_TILE, 0))

        ### Generate asteroid tiles ###
        key, subkey = jax.random.split(key)
        perlin_noise = generate_perlin_noise_2d(subkey, (map_height, map_width), (8, 8))
        noise = jnp.where(perlin_noise < -0.5, 1, 0)
        # mirror along diagonal
        noise = noise | noise.T
        noise = noise[::-1, ::1]
        # Asteroids overwrite whatever tile type was there (including nebula).
        map_features = map_features.replace(tile_type=jnp.place(map_features.tile_type, noise, ASTEROID_TILE, inplace=False))

        ### Generate relic nodes ###
        key, subkey = jax.random.split(key)
        noise = generate_perlin_noise_2d(subkey, (map_height, map_width), (4, 4))
        # Find the positions of the highest noise values (one per slot in the first half)
        flat_indices = jnp.argsort(noise.ravel())[-max_relic_nodes // 2:]
        highest_positions = jnp.column_stack(jnp.unravel_index(flat_indices, noise.shape))

        # relic nodes have a fixed density of 25% nearby tiles can yield points
        # NOTE(review): `key` is consumed here and split again below; kept as-is because
        # changing it would alter the maps produced for existing seeds.
        relic_node_configs = (
            jax.random.randint(
                key,
                shape=(
                    max_relic_nodes,
                    relic_config_size,
                    relic_config_size,
                ),
                minval=0,
                maxval=10,
            ).astype(jnp.float32)
            >= 7.5
        )
        highest_positions = highest_positions.astype(jnp.int16)
        relic_nodes_mask = relic_nodes_mask.at[0].set(True)
        relic_nodes_mask = relic_nodes_mask.at[1].set(True)
        # Second half of the slots holds the mirrored counterparts of the first half.
        mirrored_positions = jnp.stack([map_width - highest_positions[:, 1] - 1, map_height - highest_positions[:, 0] - 1], dtype=jnp.int16, axis=-1)
        relic_nodes = jnp.concat([highest_positions, mirrored_positions], axis=0)

        key, subkey = jax.random.split(key)
        relic_nodes_mask_half = jax.random.randint(key, (max_relic_nodes // 2, ), minval=0, maxval=2).astype(jnp.bool)
        # Always keep at least one relic node (and its mirror) active.
        relic_nodes_mask_half = relic_nodes_mask_half.at[0].set(True)
        relic_nodes_mask = relic_nodes_mask.at[:max_relic_nodes // 2].set(relic_nodes_mask_half)
        relic_nodes_mask = relic_nodes_mask.at[max_relic_nodes // 2:].set(relic_nodes_mask_half)
        # Mirrored nodes get the transposed, flipped config of their counterpart.
        relic_node_configs = relic_node_configs.at[max_relic_nodes // 2:].set(relic_node_configs[:max_relic_nodes // 2].transpose(0, 2, 1)[:, ::-1, ::-1])

        ### Generate energy nodes ###
        key, subkey = jax.random.split(key)
        noise = generate_perlin_noise_2d(subkey, (map_height, map_width), (4, 4))
        # Find the positions of the highest noise values
        flat_indices = jnp.argsort(noise.ravel())[-max_energy_nodes // 2:]
        highest_positions = jnp.column_stack(jnp.unravel_index(flat_indices, noise.shape)).astype(jnp.int16)
        mirrored_positions = jnp.stack([map_width - highest_positions[:, 1] - 1, map_height - highest_positions[:, 0] - 1], dtype=jnp.int16, axis=-1)
        energy_nodes = jnp.concat([highest_positions, mirrored_positions], axis=0)
        key, subkey = jax.random.split(key)
        energy_nodes_mask_half = jax.random.randint(key, (max_energy_nodes // 2, ), minval=0, maxval=2).astype(jnp.bool)
        # Always keep at least one energy node (and its mirror) active.
        energy_nodes_mask_half = energy_nodes_mask_half.at[0].set(True)
        energy_nodes_mask = energy_nodes_mask.at[:max_energy_nodes // 2].set(energy_nodes_mask_half)
        energy_nodes_mask = energy_nodes_mask.at[max_energy_nodes // 2:].set(energy_nodes_mask_half)

        # TODO (stao): provide more randomization options for energy node functions.
        # NOTE(review): hard-coded to six rows regardless of max_energy_nodes.
        energy_node_fns = jnp.array(
            [
                [0, 1.2, 1, 4],
                [0, 0, 0, 0],
                [0, 0, 0, 0],
                # [1, 4, 0, 2],
                [0, 1.2, 1, 4],
                [0, 0, 0, 0],
                [0, 0, 0, 0],
                # [1, 4, 0, 0]
            ]
        )

    return dict(
        map_features=map_features,
        energy_nodes=energy_nodes,
        energy_node_fns=energy_node_fns,
        relic_nodes=relic_nodes,
        energy_nodes_mask=energy_nodes_mask,
        relic_nodes_mask=relic_nodes_mask,
        relic_node_configs=relic_node_configs,
    )
|
|
347
|
+
def interpolant(t):
    """Quintic smoothing polynomial ``6t^5 - 15t^4 + 10t^3``, applied elementwise."""
    return t*t*t*(t*(t*6 - 15) + 10)


@functools.partial(jax.jit, static_argnums=(1, 2, 3, 4))
def generate_perlin_noise_2d(
    key, shape, res, tileable=(False, False), interpolant=interpolant
):
    """Generate a 2D JAX array of perlin noise.

    Args:
        key: JAX PRNG key used to sample the gradient angles.
        shape: The shape of the generated array (tuple of two ints).
            This must be a multiple of res.
        res: The number of periods of noise to generate along each
            axis (tuple of two ints). Note shape must be a multiple of
            res.
        tileable: If the noise should be tileable along each axis
            (tuple of two bools). Defaults to (False, False).
        interpolant: The interpolation function, defaults to
            t*t*t*(t*(t*6 - 15) + 10).

    Returns:
        A JAX array of shape shape with the generated noise.

    Raises:
        ValueError: If shape is not a multiple of res.
    """
    # shape and res are static under jit, so this check runs on plain Python ints.
    if shape[0] % res[0] != 0 or shape[1] % res[1] != 0:
        raise ValueError(f"shape {shape} must be a multiple of res {res}")
    delta = (res[0] / shape[0], res[1] / shape[1])
    d = (shape[0] // res[0], shape[1] // res[1])
    # Fractional position of every output cell inside its lattice cell.
    grid = jnp.mgrid[0:res[0]:delta[0], 0:res[1]:delta[1]]\
        .transpose(1, 2, 0) % 1
    # Gradients
    angles = 2*jnp.pi*jax.random.uniform(key, (res[0]+1, res[1]+1))
    gradients = jnp.dstack((jnp.cos(angles), jnp.sin(angles)))
    # JAX arrays are immutable, so the wrap-around copies use functional updates.
    if tileable[0]:
        gradients = gradients.at[-1, :].set(gradients[0, :])
    if tileable[1]:
        gradients = gradients.at[:, -1].set(gradients[:, 0])
    # Expand each lattice gradient to cover the d[0] x d[1] output cells it influences.
    gradients = gradients.repeat(d[0], 0).repeat(d[1], 1)
    g00 = gradients[    :-d[0],    :-d[1]]
    g10 = gradients[d[0]:     ,    :-d[1]]
    g01 = gradients[    :-d[0],d[1]:     ]
    g11 = gradients[d[0]:     ,d[1]:     ]

    # Ramps
    n00 = jnp.sum(jnp.dstack((grid[:,:,0]  , grid[:,:,1]  )) * g00, 2)
    n10 = jnp.sum(jnp.dstack((grid[:,:,0]-1, grid[:,:,1]  )) * g10, 2)
    n01 = jnp.sum(jnp.dstack((grid[:,:,0]  , grid[:,:,1]-1)) * g01, 2)
    n11 = jnp.sum(jnp.dstack((grid[:,:,0]-1, grid[:,:,1]-1)) * g11, 2)
    # Interpolation
    t = interpolant(grid)
    n0 = n00*(1-t[:,:,0]) + t[:,:,0]*n10
    n1 = n01*(1-t[:,:,0]) + t[:,:,0]*n11
    return jnp.sqrt(2)*((1-t[:,:,1])*n0 + t[:,:,1]*n1)
|
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
# TODO (stao): Add lux ai s3 env to gymnax api wrapper, which is the old gym api
|
|
2
|
+
import json
|
|
3
|
+
import os
|
|
4
|
+
from typing import Any, SupportsFloat
|
|
5
|
+
import flax
|
|
6
|
+
import flax.serialization
|
|
7
|
+
import gymnasium as gym
|
|
8
|
+
import gymnax
|
|
9
|
+
import gymnax.environments.spaces
|
|
10
|
+
import jax
|
|
11
|
+
import numpy as np
|
|
12
|
+
import dataclasses
|
|
13
|
+
from luxai_s3.env import LuxAIS3Env
|
|
14
|
+
from luxai_s3.params import EnvParams, env_params_ranges
|
|
15
|
+
from luxai_s3.state import serialize_env_actions, serialize_env_states
|
|
16
|
+
from luxai_s3.utils import to_numpy
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class LuxAIS3GymEnv(gym.Env):
    """Gymnasium-style wrapper around the JAX-based ``LuxAIS3Env``.

    Owns the PRNG key and the current environment state/parameters so callers
    can use the stateful ``reset`` / ``step`` API. When ``numpy_output`` is
    true, observations, rewards and done flags are converted with ``to_numpy``.
    """

    def __init__(self, numpy_output: bool = False):
        """Create the wrapped environment, warm it up, and define the action space.

        Args:
            numpy_output: if true, convert outputs of ``reset`` / ``step`` via ``to_numpy``.
        """
        self.numpy_output = numpy_output
        self.rng_key = jax.random.key(0)
        self.jax_env = LuxAIS3Env(auto_reset=False)
        self.env_params: EnvParams = EnvParams()

        # auto run compiling steps here:
        # print("Running compilation steps")
        # The warm-up uses its own local key, so self.rng_key is not consumed here.
        key = jax.random.key(0)
        # Reset the environment
        dummy_env_params = EnvParams(map_type=1)
        key, reset_key = jax.random.split(key)
        obs, state = self.jax_env.reset(reset_key, params=dummy_env_params)
        # Take a random action
        key, subkey = jax.random.split(key)
        action = self.jax_env.action_space(dummy_env_params).sample(subkey)
        # Step the environment and compile. Not sure why 2 steps? are needed
        for _ in range(2):
            key, subkey = jax.random.split(key)
            obs, state, reward, terminated, truncated, info = self.jax_env.step(
                subkey, state, action, params=dummy_env_params
            )
        # print("Finish compilation steps")
        # Per-unit action is 3 numbers: column 0 ranges over [0, 6], columns 1-2 over
        # [-unit_sap_range, unit_sap_range] (taken from the default EnvParams).
        low = np.zeros((self.env_params.max_units, 3))
        low[:, 1:] = -self.env_params.unit_sap_range
        high = np.ones((self.env_params.max_units, 3)) * 6
        high[:, 1:] = self.env_params.unit_sap_range
        self.action_space = gym.spaces.Dict(
            dict(
                player_0=gym.spaces.Box(low=low, high=high, dtype=np.int16),
                player_1=gym.spaces.Box(low=low, high=high, dtype=np.int16),
            )
        )

    def render(self):
        """Render the current state via the wrapped environment.

        ``self.state`` is only assigned by ``reset`` / ``step``, so ``reset``
        must have been called first.
        """
        self.jax_env.render(self.state, self.env_params)

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[Any, dict[str, Any]]:
        """Start a new episode with randomized (or caller-supplied) parameters.

        Args:
            seed: if given, re-seeds the wrapper's PRNG key.
            options: if it contains ``"params"``, that value replaces the randomized
                game parameters.

        Returns:
            ``(obs, info)`` where ``info`` holds ``params`` (the subset exposed to
            agents), ``full_params`` (all parameters as a dict) and ``state``.
        """
        if seed is not None:
            self.rng_key = jax.random.key(seed)
        self.rng_key, reset_key = jax.random.split(self.rng_key)
        # generate random game parameters
        # TODO (stao): check why this keeps recompiling when marking structs as static args
        randomized_game_params = dict()
        # One fresh subkey per parameter; each value is drawn from its list of candidates.
        for k, v in env_params_ranges.items():
            self.rng_key, subkey = jax.random.split(self.rng_key)
            randomized_game_params[k] = jax.random.choice(
                subkey, jax.numpy.array(v)
            ).item()
        params = EnvParams(**randomized_game_params)
        # Keys are still consumed above even when the caller overrides the parameters.
        if options is not None and "params" in options:
            params = options["params"]

        self.env_params = params
        obs, self.state = self.jax_env.reset(reset_key, params=params)
        if self.numpy_output:
            obs = to_numpy(flax.serialization.to_state_dict(obs))

        # only keep the following game parameters available to the agent
        params_dict = dataclasses.asdict(params)
        params_dict_kept = dict()
        for k in [
            "max_units",
            "match_count_per_episode",
            "max_steps_in_match",
            "map_height",
            "map_width",
            "num_teams",
            "unit_move_cost",
            "unit_sap_cost",
            "unit_sap_range",
            "unit_sensor_range",
        ]:
            params_dict_kept[k] = params_dict[k]
        return obs, dict(
            params=params_dict_kept, full_params=params_dict, state=self.state
        )

    def step(
        self, action: Any
    ) -> tuple[Any, SupportsFloat, bool, bool, dict[str, Any]]:
        """Advance the environment by one step using the stored state and parameters.

        Args:
            action: the action passed through to the wrapped environment.

        Returns:
            ``(obs, reward, terminated, truncated, info)`` from the wrapped
            environment; the first four are converted via ``to_numpy`` when
            ``numpy_output`` is set.
        """
        self.rng_key, step_key = jax.random.split(self.rng_key)
        obs, self.state, reward, terminated, truncated, info = self.jax_env.step(
            step_key, self.state, action, self.env_params
        )
        if self.numpy_output:
            obs = to_numpy(flax.serialization.to_state_dict(obs))
            reward = to_numpy(reward)
            terminated = to_numpy(terminated)
            truncated = to_numpy(truncated)
            # info = to_numpy(flax.serialization.to_state_dict(info))
        return obs, reward, terminated, truncated, info
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
# TODO: vectorized gym wrapper
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class RecordEpisode(gym.Wrapper):
|
|
120
|
+
def __init__(
|
|
121
|
+
self,
|
|
122
|
+
env: LuxAIS3GymEnv,
|
|
123
|
+
save_dir: str = None,
|
|
124
|
+
save_on_close: bool = True,
|
|
125
|
+
save_on_reset: bool = True,
|
|
126
|
+
):
|
|
127
|
+
super().__init__(env)
|
|
128
|
+
self.episode = dict(states=[], actions=[], metadata=dict())
|
|
129
|
+
self.episode_id = 0
|
|
130
|
+
self.save_dir = save_dir
|
|
131
|
+
self.save_on_close = save_on_close
|
|
132
|
+
self.save_on_reset = save_on_reset
|
|
133
|
+
self.episode_steps = 0
|
|
134
|
+
if save_dir is not None:
|
|
135
|
+
from pathlib import Path
|
|
136
|
+
|
|
137
|
+
Path(save_dir).mkdir(parents=True, exist_ok=True)
|
|
138
|
+
|
|
139
|
+
def reset(
|
|
140
|
+
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
|
141
|
+
) -> tuple[Any, dict[str, Any]]:
|
|
142
|
+
if self.save_on_reset and self.episode_steps > 0:
|
|
143
|
+
self._save_episode_and_reset()
|
|
144
|
+
obs, info = self.env.reset(seed=seed, options=options)
|
|
145
|
+
|
|
146
|
+
self.episode["metadata"]["seed"] = seed
|
|
147
|
+
self.episode["params"] = flax.serialization.to_state_dict(info["full_params"])
|
|
148
|
+
self.episode["states"].append(info["state"])
|
|
149
|
+
return obs, info
|
|
150
|
+
|
|
151
|
+
def step(
|
|
152
|
+
self, action: Any
|
|
153
|
+
) -> tuple[Any, SupportsFloat, bool, bool, dict[str, Any]]:
|
|
154
|
+
obs, reward, terminated, truncated, info = self.env.step(action)
|
|
155
|
+
self.episode_steps += 1
|
|
156
|
+
self.episode["states"].append(info["final_state"])
|
|
157
|
+
self.episode["actions"].append(action)
|
|
158
|
+
return obs, reward, terminated, truncated, info
|
|
159
|
+
|
|
160
|
+
def serialize_episode_data(self, episode=None):
|
|
161
|
+
if episode is None:
|
|
162
|
+
episode = self.episode
|
|
163
|
+
ret = dict()
|
|
164
|
+
ret["observations"] = serialize_env_states(episode["states"])
|
|
165
|
+
if "actions" in episode:
|
|
166
|
+
ret["actions"] = serialize_env_actions(episode["actions"])
|
|
167
|
+
ret["metadata"] = episode["metadata"]
|
|
168
|
+
ret["params"] = episode["params"]
|
|
169
|
+
return ret
|
|
170
|
+
|
|
171
|
+
def save_episode(self, save_path: str):
|
|
172
|
+
episode = self.serialize_episode_data()
|
|
173
|
+
with open(save_path, "w") as f:
|
|
174
|
+
json.dump(episode, f)
|
|
175
|
+
self.episode = dict(states=[], actions=[], metadata=dict())
|
|
176
|
+
|
|
177
|
+
def _save_episode_and_reset(self):
    """Save the buffered episode under ``self.save_dir`` and update the counters.

    The file is named ``episode_<episode_id>.json``. After the save,
    ``episode_id`` is incremented and ``episode_steps`` is set back to 0.

    When ``self.save_dir`` is ``None`` there is no directory to write to, so
    the call does nothing: the buffer and both counters are left untouched
    and the data stays available for an explicit ``save_episode(path)``
    call. Previously this case raised ``TypeError`` from ``os.path.join``.
    """
    if self.save_dir is None:
        return

    target = os.path.join(self.save_dir, f"episode_{self.episode_id}.json")
    self.save_episode(target)
    self.episode_id += 1
    self.episode_steps = 0
|
|
184
|
+
|
|
185
|
+
def close(self):
    """Flush the buffered episode to disk if auto-save on close applies.

    Nothing happens unless ``save_on_close`` is set and at least one step
    has been recorded since the last save.
    """
    # NOTE(review): the wrapped ``self.env`` is not closed here — confirm
    # whether the enclosing class is expected to forward ``close()``.
    if not (self.save_on_close and self.episode_steps > 0):
        return
    self._save_episode_and_reset()
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
if __package__ == "":
|
|
2
|
+
from lux.utils import direction_to
|
|
3
|
+
else:
|
|
4
|
+
from .lux.utils import direction_to
|
|
5
|
+
import numpy as np
|
|
6
|
+
class Agent:
    """Baseline agent: explore randomly, then hover around the first relic node found."""

    def __init__(self, player: str, env_cfg) -> None:
        """Store the player identity and the environment configuration.

        Args:
            player: ``"player_0"`` or any other string; anything other than
                ``"player_0"`` is treated as team 1.
            env_cfg: Mapping read for ``"max_units"``, ``"map_width"`` and
                ``"map_height"`` when acting.
        """
        plays_first = player == "player_0"
        self.player = player
        self.opp_player = "player_1" if plays_first else "player_0"
        self.team_id = 0 if plays_first else 1
        self.opp_team_id = 1 - self.team_id
        # Seeds NumPy's global generator so the random choices are reproducible.
        np.random.seed(0)
        self.env_cfg = env_cfg

        # Knowledge kept across calls to act().
        self.relic_node_positions = []
        self.discovered_relic_nodes_ids = set()
        self.unit_explore_locations = {}

    def act(self, step: int, obs, remainingOverageTime: int = 60):
        """Choose one action per unit for the current timestep.

        Args:
            step: Current timestep of the game, starting from 0 and going up
                to max_steps_in_match * match_count_per_episode - 1.
            obs: Observation mapping; the entries read here are
                ``units_mask``, ``units`` (``position`` / ``energy``),
                ``relic_nodes``, ``relic_nodes_mask`` and ``team_points``.
            remainingOverageTime: Not used by this baseline.

        Returns:
            An integer array of shape ``(max_units, 3)``; rows of units that
            are not available stay all-zero.
        """
        team = self.team_id
        unit_mask = np.array(obs["units_mask"][team])  # shape (max_units, )
        unit_positions = np.array(obs["units"]["position"][team])  # shape (max_units, 2)
        unit_energys = np.array(obs["units"]["energy"][team])  # shape (max_units, 1)
        observed_relic_node_positions = np.array(obs["relic_nodes"])  # shape (max_relic_nodes, 2)
        observed_relic_nodes_mask = np.array(obs["relic_nodes_mask"])  # shape (max_relic_nodes, )
        # Points of each team; team_points[self.team_id] is our own score.
        team_points = np.array(obs["team_points"])

        # Ids of the units that can be controlled at this timestep.
        available_unit_ids = np.where(unit_mask)[0]
        # Ids of the relic nodes that are currently visible.
        visible_relic_node_ids = set(np.where(observed_relic_nodes_mask)[0])

        actions = np.zeros((self.env_cfg["max_units"], 3), dtype=int)

        # Strategy: units explore random map locations until a relic node has
        # been seen; from then on every unit heads for the first relic node
        # discovered and moves randomly once close to it, hoping to score.
        # Discovered relic nodes are remembered for the rest of the game.
        for node_id in visible_relic_node_ids:
            if node_id in self.discovered_relic_nodes_ids:
                continue
            self.discovered_relic_nodes_ids.add(node_id)
            self.relic_node_positions.append(observed_relic_node_positions[node_id])

        # Unit ids range from 0 to max_units - 1.
        for unit_id in available_unit_ids:
            unit_pos = unit_positions[unit_id]
            unit_energy = unit_energys[unit_id]  # not used by this baseline

            if self.relic_node_positions:
                # Always target the first relic node that was discovered.
                target = self.relic_node_positions[0]
                dx = abs(unit_pos[0] - target[0])
                dy = abs(unit_pos[1] - target[1])
                if dx + dy <= 4:
                    # Close enough: hover around the node with a random move.
                    move = np.random.randint(0, 5)
                else:
                    # Otherwise keep walking towards the relic node.
                    move = direction_to(unit_pos, target)
            else:
                # Explore: pick a random map location and keep heading there,
                # choosing a new one every 20 steps.
                needs_target = (
                    step % 20 == 0 or unit_id not in self.unit_explore_locations
                )
                if needs_target:
                    target_x = np.random.randint(0, self.env_cfg["map_width"])
                    target_y = np.random.randint(0, self.env_cfg["map_height"])
                    self.unit_explore_locations[unit_id] = (target_x, target_y)
                move = direction_to(unit_pos, self.unit_explore_locations[unit_id])

            actions[unit_id] = [move, 0, 0]
        return actions
|
|
File without changes
|