kaggle-environments 1.15.2__py2.py3-none-any.whl → 1.16.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kaggle-environments might be problematic. Click here for more details.

Files changed (69) hide show
  1. kaggle_environments/__init__.py +1 -1
  2. kaggle_environments/envs/chess/chess.js +63 -22
  3. kaggle_environments/envs/chess/chess.json +4 -4
  4. kaggle_environments/envs/chess/chess.py +209 -51
  5. kaggle_environments/envs/chess/test_chess.py +43 -1
  6. kaggle_environments/envs/connectx/connectx.ipynb +3183 -0
  7. kaggle_environments/envs/football/football.ipynb +75 -0
  8. kaggle_environments/envs/halite/halite.ipynb +44736 -0
  9. kaggle_environments/envs/kore_fleets/kore_fleets.ipynb +112 -0
  10. kaggle_environments/envs/kore_fleets/starter_bots/java/Bot.java +54 -0
  11. kaggle_environments/envs/kore_fleets/starter_bots/java/README.md +26 -0
  12. kaggle_environments/envs/kore_fleets/starter_bots/java/jars/hamcrest-core-1.3.jar +0 -0
  13. kaggle_environments/envs/kore_fleets/starter_bots/java/jars/junit-4.13.2.jar +0 -0
  14. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Board.java +518 -0
  15. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Cell.java +61 -0
  16. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Configuration.java +24 -0
  17. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Direction.java +166 -0
  18. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Fleet.java +72 -0
  19. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/KoreJson.java +97 -0
  20. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Observation.java +72 -0
  21. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Pair.java +13 -0
  22. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Player.java +68 -0
  23. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Point.java +65 -0
  24. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Shipyard.java +70 -0
  25. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/ShipyardAction.java +59 -0
  26. kaggle_environments/envs/kore_fleets/starter_bots/java/test/BoardTest.java +567 -0
  27. kaggle_environments/envs/kore_fleets/starter_bots/java/test/ConfigurationTest.java +25 -0
  28. kaggle_environments/envs/kore_fleets/starter_bots/java/test/KoreJsonTest.java +62 -0
  29. kaggle_environments/envs/kore_fleets/starter_bots/java/test/ObservationTest.java +46 -0
  30. kaggle_environments/envs/kore_fleets/starter_bots/java/test/PointTest.java +21 -0
  31. kaggle_environments/envs/kore_fleets/starter_bots/java/test/ShipyardTest.java +22 -0
  32. kaggle_environments/envs/kore_fleets/starter_bots/ts/README.md +55 -0
  33. kaggle_environments/envs/lux_ai_2021/README.md +3 -0
  34. kaggle_environments/envs/lux_ai_2021/dimensions/754.js.LICENSE.txt +296 -0
  35. kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/simple.tar.gz +0 -0
  36. kaggle_environments/envs/lux_ai_2021/testing.md +23 -0
  37. kaggle_environments/envs/lux_ai_2021/todo.md.og +18 -0
  38. kaggle_environments/envs/lux_ai_s2/.gitignore +1 -0
  39. kaggle_environments/envs/lux_ai_s2/README.md +21 -0
  40. kaggle_environments/envs/lux_ai_s2/luxai_s2/.DS_Store +0 -0
  41. kaggle_environments/envs/lux_ai_s2/luxai_s2/map_generator/.DS_Store +0 -0
  42. kaggle_environments/envs/lux_ai_s3/README.md +21 -0
  43. kaggle_environments/envs/lux_ai_s3/agents.py +4 -0
  44. kaggle_environments/envs/lux_ai_s3/index.html +42 -0
  45. kaggle_environments/envs/lux_ai_s3/lux_ai_s3.json +47 -0
  46. kaggle_environments/envs/lux_ai_s3/lux_ai_s3.py +138 -0
  47. kaggle_environments/envs/lux_ai_s3/luxai_s3/__init__.py +1 -0
  48. kaggle_environments/envs/lux_ai_s3/luxai_s3/env.py +924 -0
  49. kaggle_environments/envs/lux_ai_s3/luxai_s3/globals.py +13 -0
  50. kaggle_environments/envs/lux_ai_s3/luxai_s3/params.py +101 -0
  51. kaggle_environments/envs/lux_ai_s3/luxai_s3/profiler.py +140 -0
  52. kaggle_environments/envs/lux_ai_s3/luxai_s3/pygame_render.py +270 -0
  53. kaggle_environments/envs/lux_ai_s3/luxai_s3/spaces.py +30 -0
  54. kaggle_environments/envs/lux_ai_s3/luxai_s3/state.py +399 -0
  55. kaggle_environments/envs/lux_ai_s3/luxai_s3/utils.py +12 -0
  56. kaggle_environments/envs/lux_ai_s3/luxai_s3/wrappers.py +187 -0
  57. kaggle_environments/envs/lux_ai_s3/test_agents/python/agent.py +71 -0
  58. kaggle_environments/envs/lux_ai_s3/test_agents/python/lux/__init__.py +0 -0
  59. kaggle_environments/envs/lux_ai_s3/test_agents/python/lux/kit.py +27 -0
  60. kaggle_environments/envs/lux_ai_s3/test_agents/python/lux/utils.py +17 -0
  61. kaggle_environments/envs/lux_ai_s3/test_agents/python/main.py +53 -0
  62. kaggle_environments/envs/lux_ai_s3/test_lux.py +9 -0
  63. kaggle_environments/envs/tictactoe/tictactoe.ipynb +1393 -0
  64. {kaggle_environments-1.15.2.dist-info → kaggle_environments-1.16.0.dist-info}/METADATA +2 -2
  65. {kaggle_environments-1.15.2.dist-info → kaggle_environments-1.16.0.dist-info}/RECORD +69 -11
  66. {kaggle_environments-1.15.2.dist-info → kaggle_environments-1.16.0.dist-info}/WHEEL +1 -1
  67. {kaggle_environments-1.15.2.dist-info → kaggle_environments-1.16.0.dist-info}/LICENSE +0 -0
  68. {kaggle_environments-1.15.2.dist-info → kaggle_environments-1.16.0.dist-info}/entry_points.txt +0 -0
  69. {kaggle_environments-1.15.2.dist-info → kaggle_environments-1.16.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,399 @@
1
+ import functools
2
+ import chex
3
+ import flax
4
+ import jax
5
+ import jax.numpy as jnp
6
+ import numpy as np
7
+ from flax import struct
8
+
9
+ from luxai_s3.params import MAP_TYPES, EnvParams
10
+ from luxai_s3.utils import to_numpy
11
# Integer tile codes written into MapTile.tile_type (gen_map uses
# NEBULA_TILE and ASTEROID_TILE; untouched tiles keep the value 0).
EMPTY_TILE = 0
NEBULA_TILE = 1
ASTEROID_TILE = 2

# Candidate energy-field functions. Each takes a value `d` plus three
# parameters (x, y, z):
#   index 0: sin(d * x + y) * z
#   index 1: (x / (d + 1) + y) * z
# NOTE(review): presumably the first column of each `energy_node_fns` row
# selects an entry of this list and `d` is the tile-to-node distance, as the
# EnvState.energy_node_fns description says — confirm against env.py.
ENERGY_NODE_FNS = [
    lambda d, x, y, z: jnp.sin(d * x + y) * z, lambda d, x, y, z: (x / (d + 1) + y) * z
]
18
+
19
@struct.dataclass
class UnitState:
    """Position and energy of units, stored as a flax struct dataclass.

    gen_state instantiates this with batched arrays: ``position`` of shape
    (num_teams, max_units, 2) and ``energy`` of shape (num_teams, max_units, 1),
    both int16.
    """
    position: chex.Array
    """Position of the unit with shape (2) for x, y"""
    energy: int
    """Energy of the unit"""
25
+
26
@struct.dataclass
class MapTile:
    """Per-tile map features, stored as a flax struct dataclass.

    gen_map instantiates this with two int16 arrays of identical 2-D shape,
    one for ``energy`` and one for ``tile_type``.
    """
    energy: int
    """Energy of the tile, generated via energy_nodes and energy_node_fns"""
    # Holds the tile codes EMPTY_TILE / NEBULA_TILE / ASTEROID_TILE.
    tile_type: int
    """Type of the tile"""
32
+
33
@struct.dataclass
class EnvState:
    """Complete simulator state; gen_state builds the initial instance."""
    # NOTE(review): gen_state stores units as a UnitState holding `position`
    # with shape (T, N, 2) and `energy` with shape (T, N, 1), not a single
    # (T, N, 3) array as the description below suggests.
    units: UnitState
    """Units in the environment with shape (T, N, 3) for T teams, N max units, and 3 features.

    3 features are for position (x, y), and energy
    """
    units_mask: chex.Array
    """Mask of units in the environment with shape (T, N) for T teams, N max units"""
    energy_nodes: chex.Array
    """Energy nodes in the environment with shape (N, 2) for N max energy nodes, and 2 features.

    2 features are for position (x, y)
    """

    energy_node_fns: chex.Array
    """Energy node functions for computing the energy field of the map. They describe the function with a sequence of numbers

    The first number is the function used. The subsequent numbers parameterize the function. The function is applied to distance of map tile to energy node and the function parameters.
    """

    # energy_field: chex.Array
    # """Energy field in the environment with shape (H, W) for H height, W width. This is generated from other state"""

    energy_nodes_mask: chex.Array
    """Mask of energy nodes in the environment with shape (N) for N max energy nodes"""
    relic_nodes: chex.Array
    """Relic nodes in the environment with shape (N, 2) for N max relic nodes, and 2 features.

    2 features are for position (x, y)
    """
    relic_node_configs: chex.Array
    """Relic node configs in the environment with shape (N, K, K) for N max relic nodes and a KxK relic configuration"""
    relic_nodes_mask: chex.Array
    """Mask of relic nodes in the environment with shape (N, ) for N max relic nodes"""
    # NOTE(review): gen_state builds this as an int16 array of shape
    # (map_width, map_height), indexed [x, y], by summing the relic configs
    # that cover each tile — it is a count, not a boolean map.
    relic_nodes_map_weights: chex.Array
    """Map of relic nodes in the environment with shape (H, W) for H height, W width. True if a relic node is present, False otherwise. This is generated from other state"""

    # A MapTile whose `energy` and `tile_type` fields are separate 2-D arrays.
    map_features: MapTile
    """Map features in the environment with shape (W, H, 2) for W width, H height
    """

    sensor_mask: chex.Array
    """Sensor mask in the environment with shape (T, H, W) for T teams, H height, W width. This is generated from other state"""

    vision_power_map: chex.Array
    """Vision power map in the environment with shape (T, H, W) for T teams, H height, W width. This is generated from other state"""

    team_points: chex.Array
    """Team points in the environment with shape (T) for T teams"""
    team_wins: chex.Array
    """Team wins in the environment with shape (T) for T teams"""

    steps: int = 0
    """steps taken in the environment"""
    match_steps: int = 0
    """steps taken in the current match"""
90
+
91
@struct.dataclass
class EnvObs:
    """Partial observation of environment"""
    # NOTE(review): typed as UnitState (separate position / energy arrays);
    # the (T, N, 3) shape below mirrors the EnvState description — confirm.
    units: UnitState
    """Units in the environment with shape (T, N, 3) for T teams, N max units, and 3 features.

    3 features are for position (x, y), and energy
    """
    units_mask: chex.Array
    """Mask of units in the environment with shape (T, N) for T teams, N max units"""

    # Undocumented in the original; presumably the visibility mask of the
    # observing team — TODO confirm against env.py.
    sensor_mask: chex.Array

    map_features: MapTile
    """Map features in the environment with shape (W, H, 2) for W width, H height
    """
    relic_nodes: chex.Array
    """Position of all relic nodes with shape (N, 2) for N max relic nodes and 2 features for position (x, y). Number is -1 if not visible"""
    relic_nodes_mask: chex.Array
    """Mask of all relic nodes with shape (N) for N max relic nodes"""
    team_points: chex.Array
    """Team points in the environment with shape (T) for T teams"""
    team_wins: chex.Array
    """Team wins in the environment with shape (T) for T teams"""
    steps: int = 0
    """steps taken in the environment"""
    match_steps: int = 0
    """steps taken in the current match"""
119
+
120
+
121
+
122
def serialize_env_states(env_states: list[EnvState]):
    """Convert a sequence of environment states into JSON-friendly dicts.

    Each state is turned into a nested dict with
    ``flax.serialization.to_state_dict`` and then walked recursively:

    * mask / derived entries listed in ``skipped_keys`` are dropped;
    * ``relic_nodes``, ``relic_node_configs`` and ``energy_nodes`` are
      filtered by their corresponding mask so only active nodes are kept;
    * every remaining array (JAX or NumPy) is converted with ``tolist()``.

    Args:
        env_states: The states to serialize, in episode order.

    Returns:
        A list with one plain-Python dict per input state.
    """
    # Top-level entries that are either masks (already applied below) or
    # values derived from other state, so they are not written out.
    skipped_keys = {
        "sensor_mask",
        "relic_nodes_mask",
        "energy_nodes_mask",
        "energy_node_fns",
        "relic_nodes_map_weights",
    }

    def serialize_array(root: EnvState, arr, key_path: str = ""):
        if key_path in skipped_keys:
            return None
        if key_path == "relic_nodes":
            return root.relic_nodes[root.relic_nodes_mask].tolist()
        if key_path == "relic_node_configs":
            return root.relic_node_configs[root.relic_nodes_mask].tolist()
        if key_path == "energy_nodes":
            return root.energy_nodes[root.energy_nodes_mask].tolist()
        # Accept NumPy arrays as well as JAX arrays (as serialize_env_actions
        # does); otherwise NumPy-backed fields would be returned unconverted
        # and break a later json.dump.
        if isinstance(arr, (np.ndarray, jnp.ndarray)):
            return arr.tolist()
        if isinstance(arr, dict):
            ret = dict()
            for k, v in arr.items():
                # key_path is the "/"-joined path from the state root.
                new_key = key_path + "/" + k if key_path else k
                new_val = serialize_array(root, v, new_key)
                if new_val is not None:
                    ret[k] = new_val
            return ret
        return arr

    steps = []
    for state in env_states:
        state_dict = flax.serialization.to_state_dict(state)
        steps.append(serialize_array(state, state_dict))

    return steps
149
+
150
def serialize_env_actions(env_actions: list):
    """Convert a sequence of actions into JSON-friendly structures.

    Every action is passed through ``flax.serialization.to_state_dict`` and
    then walked recursively: NumPy and JAX arrays become nested lists, dicts
    are rebuilt without entries whose converted value is ``None``, and any
    other value is returned unchanged.

    Args:
        env_actions: The actions to serialize, in episode order.

    Returns:
        A list with one serialized entry per input action.
    """

    def serialize_array(arr, key_path: str = ""):
        if isinstance(arr, (np.ndarray, jnp.ndarray)):
            return arr.tolist()
        if not isinstance(arr, dict):
            return arr
        converted = dict()
        for name, value in arr.items():
            # Track the "/"-joined path from the root of the action dict.
            child_path = key_path + "/" + name if key_path else name
            child = serialize_array(value, child_path)
            if child is not None:
                converted[name] = child
        return converted

    return [
        serialize_array(flax.serialization.to_state_dict(action))
        for action in env_actions
    ]
172
+
173
def state_to_flat_obs(state: EnvState) -> chex.Array:
    """Placeholder for flattening a state into one array; not implemented, returns None."""
    return None
175
+
176
+
177
def flat_obs_to_state(flat_obs: chex.Array) -> EnvState:
    """Placeholder for rebuilding a state from a flat array; not implemented, returns None."""
    return None
179
+
180
@functools.partial(jax.jit, static_argnums=(2, 3, 4, 5, 6, 7, 8, 9))
def gen_state(key: chex.PRNGKey, env_params: EnvParams, max_units: int, num_teams: int, map_type: int, map_width: int, map_height: int, max_energy_nodes: int, max_relic_nodes: int, relic_config_size: int) -> EnvState:
    """Build the initial EnvState for a freshly generated map.

    The function is jit-compiled; every argument from ``max_units`` onward is
    a static argument, so only ``key`` and ``env_params`` are traced.

    Args:
        key: PRNG key forwarded to gen_map.
        env_params: Environment parameters forwarded to gen_map.
        max_units: Number of unit slots per team.
        num_teams: Number of teams.
        map_type: Index into MAP_TYPES (consumed by gen_map).
        map_width: Map width in tiles.
        map_height: Map height in tiles.
        max_energy_nodes: Number of energy node slots.
        max_relic_nodes: Number of relic node slots.
        relic_config_size: Side length K of each KxK relic configuration.

    Returns:
        An EnvState with zeroed units, masks, points and wins, the generated
        map / node data, and the accumulated relic weight map.
    """
    # NOTE(review): gen_map's signature is (..., map_height, map_width, ...)
    # but map_width is passed first here, so gen_map's arrays come out with
    # shape (map_width, map_height). Harmless for square maps; confirm this
    # is intended before using non-square maps.
    generated = gen_map(key, env_params, map_type, map_width, map_height, max_energy_nodes, max_relic_nodes, relic_config_size)
    # Accumulator indexed [x, y]; each tile counts the relic configs covering it.
    relic_nodes_map_weights = jnp.zeros(
        shape=(map_width, map_height), dtype=jnp.int16
    )

    # TODO (this could be optimized better)
    def update_relic_node(relic_nodes_map_weights, relic_data):
        # scan body: stamp one relic node's KxK config onto the weight map,
        # centred on the node position.
        relic_node, relic_node_config, mask = relic_data
        start_y = relic_node[1] - relic_config_size // 2
        start_x = relic_node[0] - relic_config_size // 2
        # relic_config_size is static, so these Python loops are unrolled
        # at trace time.
        for dy in range(relic_config_size):
            for dx in range(relic_config_size):
                y, x = start_y + dy, start_x + dx
                valid_pos = jnp.logical_and(
                    jnp.logical_and(y >= 0, x >= 0),
                    jnp.logical_and(y < map_height, x < map_width),
                )
                # Keep the updated map only for in-bounds tiles of an active
                # (masked-in) relic node; otherwise keep the previous map.
                relic_nodes_map_weights = jnp.where(
                    valid_pos & mask,
                    relic_nodes_map_weights.at[x, y].add(relic_node_config[dx, dy].astype(jnp.int16)),
                    relic_nodes_map_weights,
                )
        return relic_nodes_map_weights, None

    # this is really slow...
    relic_nodes_map_weights, _ = jax.lax.scan(
        update_relic_node,
        relic_nodes_map_weights,
        (
            generated["relic_nodes"],
            generated["relic_node_configs"],
            generated["relic_nodes_mask"],
        ),
    )
    state = EnvState(
        units=UnitState(position=jnp.zeros(shape=(num_teams, max_units, 2), dtype=jnp.int16), energy=jnp.zeros(shape=(num_teams, max_units, 1), dtype=jnp.int16)),
        units_mask=jnp.zeros(
            shape=(num_teams, max_units), dtype=jnp.bool
        ),
        team_points=jnp.zeros(shape=(num_teams), dtype=jnp.int32),
        team_wins=jnp.zeros(shape=(num_teams), dtype=jnp.int32),
        energy_nodes=generated["energy_nodes"],
        energy_node_fns=generated["energy_node_fns"],
        energy_nodes_mask=generated["energy_nodes_mask"],
        # energy_field=jnp.zeros(shape=(params.map_height, params.map_width), dtype=jnp.int16),
        relic_nodes=generated["relic_nodes"],
        relic_nodes_mask=generated["relic_nodes_mask"],
        relic_node_configs=generated["relic_node_configs"],
        relic_nodes_map_weights=relic_nodes_map_weights,
        sensor_mask=jnp.zeros(
            shape=(num_teams, map_height, map_width),
            dtype=jnp.bool,
        ),
        vision_power_map=jnp.zeros(shape=(num_teams, map_height, map_width), dtype=jnp.int16),
        map_features=generated["map_features"],
    )
    return state
239
+
240
@functools.partial(jax.jit, static_argnums=(2, 3, 4, 5, 6, 7))
def gen_map(key: chex.PRNGKey, params: EnvParams, map_type: int, map_height: int, map_width: int, max_energy_nodes: int, max_relic_nodes: int, relic_config_size: int) -> dict:
    """Generate map tiles, relic nodes and energy nodes for a new game.

    The function is jit-compiled; every argument from ``map_type`` onward is
    static. For the "random" map type the layout is built from Perlin noise
    and each half of the node arrays is a mirrored copy of the other half.

    Args:
        key: PRNG key driving all random draws.
        params: Environment parameters (not read in this function body).
        map_type: Index into MAP_TYPES.
        map_height: Size of the first axis of the generated 2-D arrays.
        map_width: Size of the second axis of the generated 2-D arrays.
        max_energy_nodes: Number of energy node slots.
        max_relic_nodes: Number of relic node slots.
        relic_config_size: Side length K of each KxK relic configuration.

    Returns:
        A dict with keys ``map_features``, ``energy_nodes``,
        ``energy_node_fns``, ``relic_nodes``, ``energy_nodes_mask``,
        ``relic_nodes_mask`` and ``relic_node_configs``.
    """
    map_features = MapTile(energy=jnp.zeros(
        shape=(map_height, map_width), dtype=jnp.int16
    ), tile_type=jnp.zeros(
        shape=(map_height, map_width), dtype=jnp.int16
    ))
    energy_nodes = jnp.zeros(shape=(max_energy_nodes, 2), dtype=jnp.int16)
    energy_nodes_mask = jnp.zeros(shape=(max_energy_nodes), dtype=jnp.bool)
    relic_nodes = jnp.zeros(shape=(max_relic_nodes, 2), dtype=jnp.int16)
    relic_nodes_mask = jnp.zeros(shape=(max_relic_nodes), dtype=jnp.bool)

    # NOTE(review): relic_node_configs and energy_node_fns are only assigned
    # inside this branch, so any other map type would fail at the return
    # statement below — confirm MAP_TYPES only reaches here with "random".
    if MAP_TYPES[map_type] == "random":

        ### Generate nebula tiles ###
        key, subkey = jax.random.split(key)
        perlin_noise = generate_perlin_noise_2d(subkey, (map_height, map_width), (4, 4))
        noise = jnp.where(perlin_noise > 0.5, 1, 0)
        # mirror along diagonal
        # (symmetrise across the main diagonal, then reverse the first axis
        # only — the second slice has step +1 and leaves that axis unchanged)
        noise = noise | noise.T
        noise = noise[::-1, ::1]
        map_features = map_features.replace(tile_type=jnp.where(noise, NEBULA_TILE, 0))

        ### Generate asteroid tiles ###
        key, subkey = jax.random.split(key)
        perlin_noise = generate_perlin_noise_2d(subkey, (map_height, map_width), (8, 8))
        noise = jnp.where(perlin_noise < -0.5, 1, 0)
        # mirror along diagonal
        noise = noise | noise.T
        noise = noise[::-1, ::1]
        # Asteroids overwrite whatever tile type was set before (including nebula).
        map_features = map_features.replace(tile_type=jnp.place(map_features.tile_type, noise, ASTEROID_TILE, inplace=False))

        ### Generate relic nodes ###
        key, subkey = jax.random.split(key)
        noise = generate_perlin_noise_2d(subkey, (map_height, map_width), (4, 4))
        # Find the positions of the highest noise values
        # Takes the last (max_relic_nodes // 2) entries of the ascending sort
        # when max_relic_nodes is even. Unary minus binds tighter than //,
        # so this is (-max_relic_nodes) // 2 and an odd count rounds up.
        flat_indices = jnp.argsort(noise.ravel())[-max_relic_nodes // 2:]  # indices of the highest values
        highest_positions = jnp.column_stack(jnp.unravel_index(flat_indices, noise.shape))

        # Each config cell is True when a draw from {0, ..., 9} is >= 7.5,
        # i.e. for the values 8 and 9 — a 20% density (the original comment
        # stated 25%).
        # NOTE(review): `key` rather than `subkey` is passed here, and the
        # same `key` is split again further down — confirm the key reuse is
        # acceptable.
        relic_node_configs = (
            jax.random.randint(
                key,
                shape=(
                    max_relic_nodes,
                    relic_config_size,
                    relic_config_size,
                ),
                minval=0,
                maxval=10,
            ).astype(jnp.float32)
            >= 7.5
        )
        highest_positions = highest_positions.astype(jnp.int16)
        # These two entries are overwritten below when both halves of the
        # mask are assigned from relic_nodes_mask_half.
        relic_nodes_mask = relic_nodes_mask.at[0].set(True)
        relic_nodes_mask = relic_nodes_mask.at[1].set(True)
        # Mirrored partner of each node: the two coordinates are swapped and
        # each is reflected about the corresponding map extent.
        mirrored_positions = jnp.stack([map_width - highest_positions[:, 1] - 1, map_height - highest_positions[:, 0] - 1], dtype=jnp.int16, axis=-1)
        relic_nodes = jnp.concat([highest_positions, mirrored_positions], axis=0)

        key, subkey = jax.random.split(key)
        # Randomly enable nodes in the first half; slot 0 is always enabled.
        # The second half reuses the same mask so mirrored pairs stay in sync.
        relic_nodes_mask_half = jax.random.randint(key, (max_relic_nodes // 2, ), minval=0, maxval=2).astype(jnp.bool)
        relic_nodes_mask_half = relic_nodes_mask_half.at[0].set(True)
        relic_nodes_mask = relic_nodes_mask.at[:max_relic_nodes // 2].set(relic_nodes_mask_half)
        relic_nodes_mask = relic_nodes_mask.at[max_relic_nodes // 2:].set(relic_nodes_mask_half)
        # import ipdb;ipdb.set_trace()
        # Second-half configs are the first-half configs transposed and
        # reversed along both axes, matching the mirrored node positions.
        relic_node_configs = relic_node_configs.at[max_relic_nodes // 2:].set(relic_node_configs[:max_relic_nodes // 2].transpose(0, 2, 1)[:, ::-1, ::-1])

        ### Generate energy nodes ###
        key, subkey = jax.random.split(key)
        noise = generate_perlin_noise_2d(subkey, (map_height, map_width), (4, 4))
        # Find the positions of the highest noise values
        flat_indices = jnp.argsort(noise.ravel())[-max_energy_nodes // 2:]  # Get indices of highest values
        highest_positions = jnp.column_stack(jnp.unravel_index(flat_indices, noise.shape)).astype(jnp.int16)
        mirrored_positions = jnp.stack([map_width - highest_positions[:, 1] - 1, map_height - highest_positions[:, 0] - 1], dtype=jnp.int16, axis=-1)
        energy_nodes = jnp.concat([highest_positions, mirrored_positions], axis=0)
        key, subkey = jax.random.split(key)
        energy_nodes_mask_half = jax.random.randint(key, (max_energy_nodes // 2, ), minval=0, maxval=2).astype(jnp.bool)
        energy_nodes_mask_half = energy_nodes_mask_half.at[0].set(True)
        energy_nodes_mask = energy_nodes_mask.at[:max_energy_nodes // 2].set(energy_nodes_mask_half)
        energy_nodes_mask = energy_nodes_mask.at[max_energy_nodes // 2:].set(energy_nodes_mask_half)

        # TODO (stao): provide more randomization options for energy node functions.
        # NOTE(review): the table is hard-coded to 6 rows regardless of
        # max_energy_nodes — confirm callers always use 6 energy node slots.
        energy_node_fns = jnp.array(
            [
                [0, 1.2, 1, 4],
                [0, 0, 0, 0],
                [0, 0, 0, 0],
                # [1, 4, 0, 2],
                [0, 1.2, 1, 4],
                [0, 0, 0, 0],
                [0, 0, 0, 0],
                # [1, 4, 0, 0]
            ]
        )
        # import ipdb; ipdb.set_trace()
        # energy_node_fns = jnp.concat([energy_node_fns, jnp.zeros((params.max_energy_nodes - 2, 4), dtype=jnp.float32)], axis=0)


    return dict(
        map_features=map_features,
        energy_nodes=energy_nodes,
        energy_node_fns=energy_node_fns,
        relic_nodes=relic_nodes,
        energy_nodes_mask=energy_nodes_mask,
        relic_nodes_mask=relic_nodes_mask,
        relic_node_configs=relic_node_configs,
    )
347
def interpolant(t):
    """Quintic fade curve 6t^5 - 15t^4 + 10t^3 used to blend Perlin gradients."""
    return t*t*t*(t*(t*6 - 15) + 10)


@functools.partial(jax.jit, static_argnums=(1, 2, 3, 4))
def generate_perlin_noise_2d(
    key, shape, res, tileable=(False, False), interpolant=interpolant
):
    """Generate a 2D JAX array of perlin noise.

    The function is jit-compiled with every argument except ``key`` static.

    Args:
        key: JAX PRNG key used to draw the random gradient angles.
        shape: The shape of the generated array (tuple of two ints).
            This must be a multiple of res.
        res: The number of periods of noise to generate along each
            axis (tuple of two ints). Note shape must be a multiple of
            res.
        tileable: If the noise should be tileable along each axis
            (tuple of two bools). Defaults to (False, False).
        interpolant: The interpolation function, defaults to
            t*t*t*(t*(t*6 - 15) + 10).

    Returns:
        A JAX array of shape ``shape`` with the generated noise.

    Raises:
        ValueError: If shape is not a multiple of res.
    """
    # shape and res are static Python values, so this check runs at trace
    # time and gives a clear error instead of a broadcasting failure later.
    if shape[0] % res[0] != 0 or shape[1] % res[1] != 0:
        raise ValueError(
            f"shape {shape} must be a multiple of res {res} along each axis"
        )
    delta = (res[0] / shape[0], res[1] / shape[1])
    d = (shape[0] // res[0], shape[1] // res[1])
    # Fractional position of every output sample inside its lattice cell.
    grid = (
        jnp.mgrid[0:res[0]:delta[0], 0:res[1]:delta[1]].transpose(1, 2, 0) % 1
    )
    # Gradients
    angles = 2*jnp.pi*jax.random.uniform(key, (res[0]+1, res[1]+1))
    gradients = jnp.dstack((jnp.cos(angles), jnp.sin(angles)))
    # JAX arrays are immutable: item assignment raises TypeError, so the
    # wrap-around copy must use the functional .at[...].set(...) update.
    if tileable[0]:
        gradients = gradients.at[-1, :].set(gradients[0, :])
    if tileable[1]:
        gradients = gradients.at[:, -1].set(gradients[:, 0])
    # Expand each lattice gradient over the d[0] x d[1] samples of its cell,
    # then slice out the four corner gradients for every sample.
    gradients = gradients.repeat(d[0], 0).repeat(d[1], 1)
    g00 = gradients[:-d[0], :-d[1]]
    g10 = gradients[d[0]:, :-d[1]]
    g01 = gradients[:-d[0], d[1]:]
    g11 = gradients[d[0]:, d[1]:]

    # Ramps
    n00 = jnp.sum(jnp.dstack((grid[:, :, 0], grid[:, :, 1])) * g00, 2)
    n10 = jnp.sum(jnp.dstack((grid[:, :, 0]-1, grid[:, :, 1])) * g10, 2)
    n01 = jnp.sum(jnp.dstack((grid[:, :, 0], grid[:, :, 1]-1)) * g01, 2)
    n11 = jnp.sum(jnp.dstack((grid[:, :, 0]-1, grid[:, :, 1]-1)) * g11, 2)
    # Interpolation
    t = interpolant(grid)
    n0 = n00*(1-t[:, :, 0]) + t[:, :, 0]*n10
    n1 = n01*(1-t[:, :, 0]) + t[:, :, 0]*n11
    return jnp.sqrt(2)*((1-t[:, :, 1])*n0 + t[:, :, 1]*n1)
@@ -0,0 +1,12 @@
1
+ import numpy as np
2
+
3
+
4
def to_numpy(x):
    """Convert ``x`` to NumPy, recursing into dict values.

    An existing ``np.ndarray`` is returned as-is (same object), a dict is
    rebuilt with each value converted, and anything else (lists, scalars,
    other array-likes) is wrapped with ``np.array``.
    """
    if isinstance(x, np.ndarray):
        return x
    if isinstance(x, dict):
        return {name: to_numpy(value) for name, value in x.items()}
    return np.array(x)
@@ -0,0 +1,187 @@
1
+ # TODO (stao): Add lux ai s3 env to gymnax api wrapper, which is the old gym api
2
+ import json
3
+ import os
4
+ from typing import Any, SupportsFloat
5
+ import flax
6
+ import flax.serialization
7
+ import gymnasium as gym
8
+ import gymnax
9
+ import gymnax.environments.spaces
10
+ import jax
11
+ import numpy as np
12
+ import dataclasses
13
+ from luxai_s3.env import LuxAIS3Env
14
+ from luxai_s3.params import EnvParams, env_params_ranges
15
+ from luxai_s3.state import serialize_env_actions, serialize_env_states
16
+ from luxai_s3.utils import to_numpy
17
+
18
+
19
class LuxAIS3GymEnv(gym.Env):
    """Gymnasium-style wrapper around the JAX-based LuxAIS3Env.

    Owns the PRNG key, the current EnvParams and the current state, and
    optionally converts outputs to NumPy.
    """

    def __init__(self, numpy_output: bool = False):
        """Create the JAX env, run warm-up steps, and build the action space.

        Args:
            numpy_output: When True, reset/step convert observations (and
                step's reward / terminated / truncated) with to_numpy.
        """
        self.numpy_output = numpy_output
        self.rng_key = jax.random.key(0)
        self.jax_env = LuxAIS3Env(auto_reset=False)
        self.env_params: EnvParams = EnvParams()

        # auto run compiling steps here:
        # print("Running compilation steps")
        # The warm-up uses its own local key and local `state`; it does not
        # touch self.rng_key and does not set self.state.
        key = jax.random.key(0)
        # Reset the environment
        dummy_env_params = EnvParams(map_type=1)
        key, reset_key = jax.random.split(key)
        obs, state = self.jax_env.reset(reset_key, params=dummy_env_params)
        # Take a random action
        key, subkey = jax.random.split(key)
        action = self.jax_env.action_space(dummy_env_params).sample(subkey)
        # Step the environment and compile. Not sure why 2 steps? are needed
        for _ in range(2):
            key, subkey = jax.random.split(key)
            obs, state, reward, terminated, truncated, info = self.jax_env.step(
                subkey, state, action, params=dummy_env_params
            )
        # print("Finish compilation steps")
        # Per-unit action bounds: column 0 lies in [0, 6]; columns 1-2 lie in
        # [-unit_sap_range, +unit_sap_range]. Built from the default
        # EnvParams and not refreshed when reset() picks new parameters.
        low = np.zeros((self.env_params.max_units, 3))
        low[:, 1:] = -self.env_params.unit_sap_range
        high = np.ones((self.env_params.max_units, 3)) * 6
        high[:, 1:] = self.env_params.unit_sap_range
        self.action_space = gym.spaces.Dict(
            dict(
                player_0=gym.spaces.Box(low=low, high=high, dtype=np.int16),
                player_1=gym.spaces.Box(low=low, high=high, dtype=np.int16),
            )
        )

    def render(self):
        """Render the current state via the underlying JAX env.

        self.state is only assigned by reset(), so calling this before the
        first reset raises AttributeError.
        """
        self.jax_env.render(self.state, self.env_params)

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[Any, dict[str, Any]]:
        """Start a new episode with randomized (or caller-provided) parameters.

        Args:
            seed: If given, re-seeds the wrapper's PRNG key.
            options: If it contains "params", that value replaces the
                randomly drawn EnvParams.

        Returns:
            ``(obs, info)`` where info holds ``params`` (the subset exposed
            to agents), ``full_params`` (all parameters) and ``state``.
        """
        if seed is not None:
            self.rng_key = jax.random.key(seed)
        self.rng_key, reset_key = jax.random.split(self.rng_key)
        # generate random game parameters
        # TODO (stao): check why this keeps recompiling when marking structs as static args
        # The random draws happen (and advance self.rng_key) even when
        # options["params"] overrides the result below.
        randomized_game_params = dict()
        for k, v in env_params_ranges.items():
            self.rng_key, subkey = jax.random.split(self.rng_key)
            randomized_game_params[k] = jax.random.choice(
                subkey, jax.numpy.array(v)
            ).item()
        params = EnvParams(**randomized_game_params)
        if options is not None and "params" in options:
            params = options["params"]

        self.env_params = params
        obs, self.state = self.jax_env.reset(reset_key, params=params)
        if self.numpy_output:
            obs = to_numpy(flax.serialization.to_state_dict(obs))

        # only keep the following game parameters available to the agent
        params_dict = dataclasses.asdict(params)
        params_dict_kept = dict()
        for k in [
            "max_units",
            "match_count_per_episode",
            "max_steps_in_match",
            "map_height",
            "map_width",
            "num_teams",
            "unit_move_cost",
            "unit_sap_cost",
            "unit_sap_range",
            "unit_sensor_range",
        ]:
            params_dict_kept[k] = params_dict[k]
        return obs, dict(
            params=params_dict_kept, full_params=params_dict, state=self.state
        )

    def step(
        self, action: Any
    ) -> tuple[Any, SupportsFloat, bool, bool, dict[str, Any]]:
        """Advance the environment by one step with a fresh PRNG sub-key.

        Args:
            action: Action passed straight through to the JAX env.

        Returns:
            ``(obs, reward, terminated, truncated, info)``; everything but
            ``info`` is converted with to_numpy when numpy_output is set.
        """
        self.rng_key, step_key = jax.random.split(self.rng_key)
        obs, self.state, reward, terminated, truncated, info = self.jax_env.step(
            step_key, self.state, action, self.env_params
        )
        if self.numpy_output:
            obs = to_numpy(flax.serialization.to_state_dict(obs))
            reward = to_numpy(reward)
            terminated = to_numpy(terminated)
            truncated = to_numpy(truncated)
            # info = to_numpy(flax.serialization.to_state_dict(info))
        return obs, reward, terminated, truncated, info
114
+
115
+
116
+ # TODO: vectorized gym wrapper
117
+
118
+
119
class RecordEpisode(gym.Wrapper):
    """Wrapper that records states and actions and writes episodes as JSON.

    States come from ``info["state"]`` on reset and ``info["final_state"]``
    on step. A finished episode is written to
    ``<save_dir>/episode_<id>.json`` on the next reset and/or on close,
    depending on the ``save_on_*`` flags.
    """

    def __init__(
        self,
        env: LuxAIS3GymEnv,
        save_dir: str = None,
        save_on_close: bool = True,
        save_on_reset: bool = True,
    ):
        """Wrap ``env`` and prepare the output directory.

        Args:
            env: The environment to record.
            save_dir: Directory for episode files; created if missing. When
                None, nothing is written automatically.
            save_on_close: Save the pending episode when close() is called.
            save_on_reset: Save the pending episode when reset() is called.
        """
        super().__init__(env)
        self.episode = dict(states=[], actions=[], metadata=dict())
        self.episode_id = 0
        self.save_dir = save_dir
        self.save_on_close = save_on_close
        self.save_on_reset = save_on_reset
        self.episode_steps = 0
        if save_dir is not None:
            from pathlib import Path

            Path(save_dir).mkdir(parents=True, exist_ok=True)

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[Any, dict[str, Any]]:
        """Flush the previous episode (if any) and start recording a new one."""
        if self.save_on_reset and self.episode_steps > 0:
            self._save_episode_and_reset()
        obs, info = self.env.reset(seed=seed, options=options)

        self.episode["metadata"]["seed"] = seed
        self.episode["params"] = flax.serialization.to_state_dict(info["full_params"])
        self.episode["states"].append(info["state"])
        return obs, info

    def step(
        self, action: Any
    ) -> tuple[Any, SupportsFloat, bool, bool, dict[str, Any]]:
        """Step the wrapped env and record the resulting state and the action."""
        obs, reward, terminated, truncated, info = self.env.step(action)
        self.episode_steps += 1
        self.episode["states"].append(info["final_state"])
        self.episode["actions"].append(action)
        return obs, reward, terminated, truncated, info

    def serialize_episode_data(self, episode=None):
        """Return a JSON-serializable dict for ``episode`` (default: the current one).

        The result has ``observations``, ``metadata`` and ``params`` keys,
        plus ``actions`` when the episode recorded any.
        """
        if episode is None:
            episode = self.episode
        ret = dict()
        ret["observations"] = serialize_env_states(episode["states"])
        if "actions" in episode:
            ret["actions"] = serialize_env_actions(episode["actions"])
        ret["metadata"] = episode["metadata"]
        ret["params"] = episode["params"]
        return ret

    def save_episode(self, save_path: str):
        """Write the current episode to ``save_path`` as JSON and clear the buffer."""
        episode = self.serialize_episode_data()
        with open(save_path, "w") as f:
            json.dump(episode, f)
        self.episode = dict(states=[], actions=[], metadata=dict())

    def _save_episode_and_reset(self):
        """Saves to a generated path based on self.save_dir and the episode id, and updates relevant counters."""
        if self.save_dir is not None:
            self.save_episode(
                os.path.join(self.save_dir, f"episode_{self.episode_id}.json")
            )
        else:
            # No output directory configured: os.path.join(None, ...) would
            # raise TypeError, so just drop the buffered episode to keep
            # consecutive episodes from being mixed together.
            self.episode = dict(states=[], actions=[], metadata=dict())
        self.episode_id += 1
        self.episode_steps = 0

    def close(self):
        """Flush the pending episode (if configured) and close the wrapped env."""
        if self.save_on_close and self.episode_steps > 0:
            self._save_episode_and_reset()
        # Delegate so the wrapped environment is closed as well.
        return super().close()
@@ -0,0 +1,71 @@
1
+ if __package__ == "":
2
+ from lux.utils import direction_to
3
+ else:
4
+ from .lux.utils import direction_to
5
+ import numpy as np
6
class Agent():
    """Baseline agent: explore randomly until a relic node is seen, then hover around it."""

    def __init__(self, player: str, env_cfg) -> None:
        """Store the player identity and config and reset the agent's memory.

        Also seeds NumPy's global RNG with 0 so runs are reproducible.
        """
        self.player = player
        if self.player == "player_0":
            self.opp_player = "player_1"
            self.team_id = 0
        else:
            self.opp_player = "player_0"
            self.team_id = 1
        self.opp_team_id = 1 - self.team_id
        np.random.seed(0)
        self.env_cfg = env_cfg

        # Memory kept across steps (and matches).
        self.relic_node_positions = []
        self.discovered_relic_nodes_ids = set()
        self.unit_explore_locations = dict()

    def act(self, step: int, obs, remainingOverageTime: int = 60):
        """Decide one action per unit slot for this timestep.

        step is the current timestep number of the game starting from 0 going up to max_steps_in_match * match_count_per_episode - 1.

        Returns an int array of shape (max_units, 3); slots of units that are
        not available stay all-zero.
        """
        team = self.team_id
        unit_mask = np.array(obs["units_mask"][team])  # shape (max_units, )
        unit_positions = np.array(obs["units"]["position"][team])  # shape (max_units, 2)
        unit_energys = np.array(obs["units"]["energy"][team])  # shape (max_units, 1)
        observed_relic_node_positions = np.array(obs["relic_nodes"])  # shape (max_relic_nodes, 2)
        observed_relic_nodes_mask = np.array(obs["relic_nodes_mask"])  # shape (max_relic_nodes, )
        team_points = np.array(obs["team_points"])  # points per team; team_points[self.team_id] is ours

        # ids of the units we can control now, and of the relic nodes in view
        available_unit_ids = np.where(unit_mask)[0]
        visible_relic_node_ids = set(np.where(observed_relic_nodes_mask)[0])

        actions = np.zeros((self.env_cfg["max_units"], 3), dtype=int)

        # Strategy: units wander to random map locations until some relic node
        # has been seen; afterwards every unit heads to the first remembered
        # relic node and moves randomly around it to score points. Relic node
        # positions are remembered for the rest of the game.
        for relic_id in visible_relic_node_ids:
            if relic_id in self.discovered_relic_nodes_ids:
                continue
            self.discovered_relic_nodes_ids.add(relic_id)
            self.relic_node_positions.append(observed_relic_node_positions[relic_id])

        # unit ids range from 0 to max_units - 1
        for unit_id in available_unit_ids:
            unit_pos = unit_positions[unit_id]
            unit_energy = unit_energys[unit_id]

            if self.relic_node_positions:
                # Always target the first relic node that was discovered.
                target = self.relic_node_positions[0]
                manhattan_distance = abs(unit_pos[0] - target[0]) + abs(unit_pos[1] - target[1])
                if manhattan_distance <= 4:
                    # close enough: hover around the node with a random move
                    move = np.random.randint(0, 5)
                else:
                    # still far away: walk towards the node
                    move = direction_to(unit_pos, target)
                actions[unit_id] = [move, 0, 0]
                continue

            # No relic node known yet: pick a new random destination every
            # 20 steps (or when the unit has none) and move towards it.
            if step % 20 == 0 or unit_id not in self.unit_explore_locations:
                rand_loc = (np.random.randint(0, self.env_cfg["map_width"]), np.random.randint(0, self.env_cfg["map_height"]))
                self.unit_explore_locations[unit_id] = rand_loc
            actions[unit_id] = [direction_to(unit_pos, self.unit_explore_locations[unit_id]), 0, 0]
        return actions