parabellum 0.2.13__tar.gz → 0.2.15__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- parabellum-0.2.15/PKG-INFO +104 -0
- parabellum-0.2.15/README.md +78 -0
- parabellum-0.2.15/parabellum/__init__.py +4 -0
- {parabellum-0.2.13 → parabellum-0.2.15}/parabellum/env.py +145 -49
- parabellum-0.2.15/parabellum/map.py +16 -0
- parabellum-0.2.15/parabellum/run.py +127 -0
- {parabellum-0.2.13 → parabellum-0.2.15}/parabellum/vis.py +84 -55
- {parabellum-0.2.13 → parabellum-0.2.15}/pyproject.toml +3 -2
- parabellum-0.2.13/PKG-INFO +0 -56
- parabellum-0.2.13/README.md +0 -31
- parabellum-0.2.13/parabellum/.ipynb_checkpoints/__init__-checkpoint.py +0 -4
- parabellum-0.2.13/parabellum/.ipynb_checkpoints/env-checkpoint.py +0 -296
- parabellum-0.2.13/parabellum/.ipynb_checkpoints/vis-checkpoint.py +0 -230
- parabellum-0.2.13/parabellum/__init__.py +0 -4
@@ -0,0 +1,104 @@
|
|
1
|
+
Metadata-Version: 2.1
|
2
|
+
Name: parabellum
|
3
|
+
Version: 0.2.15
|
4
|
+
Summary: Parabellum environment for parallel warfare simulation
|
5
|
+
Home-page: https://github.com/syrkis/parabellum
|
6
|
+
License: MIT
|
7
|
+
Keywords: warfare,simulation,parallel,environment
|
8
|
+
Author: Noah Syrkis
|
9
|
+
Author-email: desk@syrkis.com
|
10
|
+
Requires-Python: >=3.11,<4.0
|
11
|
+
Classifier: License :: OSI Approved :: MIT License
|
12
|
+
Classifier: Programming Language :: Python :: 3
|
13
|
+
Classifier: Programming Language :: Python :: 3.11
|
14
|
+
Classifier: Programming Language :: Python :: 3.12
|
15
|
+
Requires-Dist: darkdetect (>=0.8.0,<0.9.0)
|
16
|
+
Requires-Dist: jax (==0.4.17)
|
17
|
+
Requires-Dist: jaxmarl (==0.0.3)
|
18
|
+
Requires-Dist: jupyterlab (>=4.2.2,<5.0.0)
|
19
|
+
Requires-Dist: moviepy (>=1.0.3,<2.0.0)
|
20
|
+
Requires-Dist: poetry (>=1.8.3,<2.0.0)
|
21
|
+
Requires-Dist: pygame (>=2.5.2,<3.0.0)
|
22
|
+
Requires-Dist: tqdm (>=4.66.4,<5.0.0)
|
23
|
+
Project-URL: Repository, https://github.com/syrkis/parabellum
|
24
|
+
Description-Content-Type: text/markdown
|
25
|
+
|
26
|
+
# parabellum
|
27
|
+
|
28
|
+
Parabellum is an ultra-scalable, high-performance warfare simulation engine.
|
29
|
+
It is based on JaxMARL's SMAX environment, but has been heavily modified to
|
30
|
+
support a wide range of new features and improvements.
|
31
|
+
|
32
|
+
## Installation
|
33
|
+
|
34
|
+
Parabellum is written in Python 3.11 and can be installed using pip:
|
35
|
+
|
36
|
+
```bash
|
37
|
+
pip install parabellum
|
38
|
+
```
|
39
|
+
|
40
|
+
## Usage
|
41
|
+
|
42
|
+
Parabellum is designed to be used in conjunction with JAX, a high-performance
|
43
|
+
numerical computing library. Here is a simple example of how to use Parabellum
|
44
|
+
to simulate a game with 10 agents and 10 enemies, each taking random actions:
|
45
|
+
|
46
|
+
```python
|
47
|
+
import parabellum as pb
|
48
|
+
from jax import random
|
49
|
+
|
50
|
+
# define the scenario
|
51
|
+
kwargs = dict(obstacle_coords=[(7, 7)], obstacle_deltas=[(10, 0)])
|
52
|
+
scenario = pb.Scenario(**kwargs) # <- Scenario is an important part of parabellum
|
53
|
+
|
54
|
+
# create the environment
|
55
|
+
kwargs = dict(map_width=256, map_height=256, num_agents=10, num_enemies=10)
|
56
|
+
env = pb.Parabellum(**kwargs) # <- Parabellum is the central class of parabellum
|
57
|
+
|
58
|
+
# initiate stochasticity
|
59
|
+
rng = random.PRNGKey(0)
|
60
|
+
rng, key = random.split(rng)
|
61
|
+
|
62
|
+
# initialize the environment state
|
63
|
+
obs, state = env.reset(key)
|
64
|
+
state_sequence = []
|
65
|
+
|
66
|
+
for _ in range(1000):
|
67
|
+
|
68
|
+
# manage stochasticity
|
69
|
+
rng, rng_act, key_step = random.split(key)
|
70
|
+
key_act = random.split(rng_act, len(env.agents))
|
71
|
+
|
72
|
+
# sample actions and append to state sequence
|
73
|
+
act = {a: env.action_space(a).sample(k)
|
74
|
+
for a, k in zip(env.agents, key_act)}
|
75
|
+
|
76
|
+
# step the environment
|
77
|
+
state_sequence.append((key_act, state, act))
|
78
|
+
obs, state, reward, done, info = env.step(key_step, act, state)
|
79
|
+
|
80
|
+
|
81
|
+
# save visualization of the state sequence
|
82
|
+
vis = pb.Visualizer(env, state_sequence) # <- Visualizer is a nice to have class
|
83
|
+
vis.animate()
|
84
|
+
```
|
85
|
+
|
86
|
+
|
87
|
+
## Features
|
88
|
+
|
89
|
+
- Obstacles — can be inserted into a scenario via the `obstacle_coords` and `obstacle_deltas` arguments of `pb.Scenario`.
|
90
|
+
|
91
|
+
## TODO
|
92
|
+
|
93
|
+
- [x] Parallel pygame vis
|
94
|
+
- [ ] Parallel bullet renderings
|
95
|
+
- [ ] Combine parallel plots into one (maybe out of parabellum scope)
|
96
|
+
- [ ] Color for health?
|
97
|
+
- [ ] Add the ability to watch an ongoing game.
|
98
|
+
- [ ] Test friendly fire for bugs.
|
99
|
+
- [x] Start sim from arbitrary state.
|
100
|
+
- [ ] Record when the episode ends in a state/obs variable
|
101
|
+
- [ ] Find the source of the bug that occurs when there are more allies than enemies
|
102
|
+
- [ ] Inverted Y axis in the parabellum visualization
|
103
|
+
- [ ] Units see through obstacles?
|
104
|
+
|
@@ -0,0 +1,78 @@
|
|
1
|
+
# parabellum
|
2
|
+
|
3
|
+
Parabellum is an ultra-scalable, high-performance warfare simulation engine.
|
4
|
+
It is based on JaxMARL's SMAX environment, but has been heavily modified to
|
5
|
+
support a wide range of new features and improvements.
|
6
|
+
|
7
|
+
## Installation
|
8
|
+
|
9
|
+
Parabellum is written in Python 3.11 and can be installed using pip:
|
10
|
+
|
11
|
+
```bash
|
12
|
+
pip install parabellum
|
13
|
+
```
|
14
|
+
|
15
|
+
## Usage
|
16
|
+
|
17
|
+
Parabellum is designed to be used in conjunction with JAX, a high-performance
|
18
|
+
numerical computing library. Here is a simple example of how to use Parabellum
|
19
|
+
to simulate a game with 10 agents and 10 enemies, each taking random actions:
|
20
|
+
|
21
|
+
```python
|
22
|
+
import parabellum as pb
|
23
|
+
from jax import random
|
24
|
+
|
25
|
+
# define the scenario
|
26
|
+
kwargs = dict(obstacle_coords=[(7, 7)], obstacle_deltas=[(10, 0)])
|
27
|
+
scenario = pb.Scenario(**kwargs) # <- Scenario is an important part of parabellum
|
28
|
+
|
29
|
+
# create the environment
|
30
|
+
kwargs = dict(map_width=256, map_height=256, num_agents=10, num_enemies=10)
|
31
|
+
env = pb.Parabellum(**kwargs) # <- Parabellum is the central class of parabellum
|
32
|
+
|
33
|
+
# initiate stochasticity
|
34
|
+
rng = random.PRNGKey(0)
|
35
|
+
rng, key = random.split(rng)
|
36
|
+
|
37
|
+
# initialize the environment state
|
38
|
+
obs, state = env.reset(key)
|
39
|
+
state_sequence = []
|
40
|
+
|
41
|
+
for _ in range(1000):
|
42
|
+
|
43
|
+
# manage stochasticity
|
44
|
+
rng, rng_act, key_step = random.split(key)
|
45
|
+
key_act = random.split(rng_act, len(env.agents))
|
46
|
+
|
47
|
+
# sample actions and append to state sequence
|
48
|
+
act = {a: env.action_space(a).sample(k)
|
49
|
+
for a, k in zip(env.agents, key_act)}
|
50
|
+
|
51
|
+
# step the environment
|
52
|
+
state_sequence.append((key_act, state, act))
|
53
|
+
obs, state, reward, done, info = env.step(key_step, act, state)
|
54
|
+
|
55
|
+
|
56
|
+
# save visualization of the state sequence
|
57
|
+
vis = pb.Visualizer(env, state_sequence) # <- Visualizer is a nice to have class
|
58
|
+
vis.animate()
|
59
|
+
```
|
60
|
+
|
61
|
+
|
62
|
+
## Features
|
63
|
+
|
64
|
+
- Obstacles — can be inserted into a scenario via the `obstacle_coords` and `obstacle_deltas` arguments of `pb.Scenario`.
|
65
|
+
|
66
|
+
## TODO
|
67
|
+
|
68
|
+
- [x] Parallel pygame vis
|
69
|
+
- [ ] Parallel bullet renderings
|
70
|
+
- [ ] Combine parallel plots into one (maybe out of parabellum scope)
|
71
|
+
- [ ] Color for health?
|
72
|
+
- [ ] Add the ability to watch an ongoing game.
|
73
|
+
- [ ] Test friendly fire for bugs.
|
74
|
+
- [x] Start sim from arbitrary state.
|
75
|
+
- [ ] Record when the episode ends in a state/obs variable
|
76
|
+
- [ ] Find the source of the bug that occurs when there are more allies than enemies
|
77
|
+
- [ ] Inverted Y axis in the parabellum visualization
|
78
|
+
- [ ] Units see through obstacles?
|
@@ -7,6 +7,7 @@ from jax import random
|
|
7
7
|
from jax import jit
|
8
8
|
from flax.struct import dataclass
|
9
9
|
import chex
|
10
|
+
from jax import vmap
|
10
11
|
from jaxmarl.environments.smax.smax_env import State, SMAX
|
11
12
|
from typing import Tuple, Dict
|
12
13
|
from functools import partial
|
@@ -16,7 +17,9 @@ from functools import partial
|
|
16
17
|
class Scenario:
|
17
18
|
"""Parabellum scenario"""
|
18
19
|
|
19
|
-
|
20
|
+
terrain_raster: chex.Array
|
21
|
+
|
22
|
+
obstacle_coords: chex.Array # TODO: use map instead of obstacles
|
20
23
|
obstacle_deltas: chex.Array
|
21
24
|
|
22
25
|
unit_types: chex.Array
|
@@ -30,8 +33,9 @@ class Scenario:
|
|
30
33
|
# default scenario
|
31
34
|
scenarios = {
|
32
35
|
"default": Scenario(
|
33
|
-
jnp.
|
34
|
-
jnp.array([[
|
36
|
+
jnp.eye(128, dtype=jnp.uint8),
|
37
|
+
jnp.array([[80, 0], [16, 12]]),
|
38
|
+
jnp.array([[0, 80], [0, 20]]),
|
35
39
|
jnp.zeros((19,), dtype=jnp.uint8),
|
36
40
|
9,
|
37
41
|
10,
|
@@ -40,24 +44,78 @@ scenarios = {
|
|
40
44
|
|
41
45
|
|
42
46
|
class Parabellum(SMAX):
|
43
|
-
def __init__(
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
self.unit_type_attack_blasts =
|
51
|
-
self.obstacle_coords = scenario.obstacle_coords.astype(jnp.float32)
|
52
|
-
self.obstacle_deltas = scenario.obstacle_deltas.astype(jnp.float32)
|
47
|
+
def __init__(self, scenario: Scenario, **kwargs):
|
48
|
+
map_height, map_width = scenario.terrain_raster.shape
|
49
|
+
args = dict(scenario=scenario, map_height=map_height, map_width=map_width)
|
50
|
+
super(Parabellum, self).__init__(**args, **kwargs)
|
51
|
+
self.terrain_raster = scenario.terrain_raster
|
52
|
+
self.obstacle_coords = scenario.obstacle_coords
|
53
|
+
self.obstacle_deltas = scenario.obstacle_deltas
|
54
|
+
self.unit_type_attack_blasts = jnp.zeros((19,), dtype=jnp.float32)
|
53
55
|
self.max_steps = 200
|
54
|
-
# overwrite
|
55
|
-
|
56
|
-
|
57
|
-
def
|
58
|
-
|
59
|
-
|
60
|
-
|
56
|
+
self._push_units_away = lambda x: x # overwrite push units
|
57
|
+
|
58
|
+
@partial(jax.jit, static_argnums=(0,))
def reset(self, key: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]:
    """Environment-specific reset.

    Build the initial ``State`` for a new episode.

    Team 0 (the first ``num_allies`` units) starts around
    ``(map_width / 4, map_height / 2)``. Team 1 (the remaining
    ``num_enemies`` units) starts around ``(3 * map_width / 4, map_height / 2)``.
    Every unit is jittered by uniform noise in ``[-2, 2)`` per coordinate.

    When ``self.smacv2_position_generation`` /
    ``self.smacv2_unit_type_generation`` are set, the output of the
    corresponding generator replaces the default positions / unit types.

    Args:
        key: PRNG key; all randomness of the reset is derived from it.

    Returns:
        ``(obs, state)``: ``obs`` is the dict returned by ``self.get_obs``
        with an extra ``"world_state"`` entry; ``state`` is the new ``State``.
    """
    key, team_0_key, team_1_key = jax.random.split(key, num=3)
    # every team-0 unit starts at the same point (list repetition, then stack)
    team_0_start = jnp.stack(
        [jnp.array([self.map_width / 4, self.map_height / 2])] * self.num_allies
    )
    team_0_start_noise = jax.random.uniform(
        team_0_key, shape=(self.num_allies, 2), minval=-2, maxval=2
    )
    team_0_start = team_0_start + team_0_start_noise
    # team 1 mirrors team 0 on the right-hand quarter line
    team_1_start = jnp.stack(
        [jnp.array([self.map_width / 4 * 3, self.map_height / 2])]
        * self.num_enemies
    )
    team_1_start_noise = jax.random.uniform(
        team_1_key, shape=(self.num_enemies, 2), minval=-2, maxval=2
    )
    team_1_start = team_1_start + team_1_start_noise
    unit_positions = jnp.concatenate([team_0_start, team_1_start])
    key, pos_key = jax.random.split(key)
    # the generator is always called; select only decides which result is kept
    generated_unit_positions = self.position_generator.generate(pos_key)
    unit_positions = jax.lax.select(
        self.smacv2_position_generation, generated_unit_positions, unit_positions
    )
    # team id per unit: 0 for the first num_allies units, 1 for the rest
    unit_teams = jnp.zeros((self.num_agents,))
    unit_teams = unit_teams.at[self.num_allies :].set(1)
    unit_weapon_cooldowns = jnp.zeros((self.num_agents,))
    # default behaviour spawn all marines
    # NOTE(review): self.scenario is presumably the per-unit type array stored
    # by SMAX.__init__ from the `scenario` argument — confirm against jaxmarl.
    unit_types = (
        jnp.zeros((self.num_agents,), dtype=jnp.uint8)
        if self.scenario is None
        else self.scenario
    )
    key, unit_type_key = jax.random.split(key)
    generated_unit_types = self.unit_type_generator.generate(unit_type_key)
    unit_types = jax.lax.select(
        self.smacv2_unit_type_generation, generated_unit_types, unit_types
    )
    # starting health is looked up per unit type
    unit_health = self.unit_type_health[unit_types]
    state = State(
        unit_positions=unit_positions,
        unit_alive=jnp.ones((self.num_agents,), dtype=jnp.bool_),
        unit_teams=unit_teams,
        unit_health=unit_health,
        unit_types=unit_types,
        prev_movement_actions=jnp.zeros((self.num_agents, 2)),
        prev_attack_actions=jnp.zeros((self.num_agents,), dtype=jnp.int32),
        time=0,
        terminal=False,
        unit_weapon_cooldowns=unit_weapon_cooldowns,
    )
    # Parabellum.__init__ replaces _push_units_away with an identity lambda,
    # so this call leaves the state unchanged
    state = self._push_units_away(state)
    obs = self.get_obs(state)
    world_state = self.get_world_state(state)
    # expose the global state alongside per-agent observations, without gradients
    obs["world_state"] = jax.lax.stop_gradient(world_state)
    return obs, state
|
115
|
+
|
116
|
+
def _our_push_units_away(
|
117
|
+
self, pos, unit_types, firmness: float = 1.0
|
118
|
+
): # copy of SMAX._push_units_away but used without state and called inside _world_step to allow more obstacles constraints
|
61
119
|
delta_matrix = pos[:, None] - pos[None, :]
|
62
120
|
dist_matrix = (
|
63
121
|
jnp.linalg.norm(delta_matrix, axis=-1)
|
@@ -74,7 +132,7 @@ class Parabellum(SMAX):
|
|
74
132
|
+ firmness * jnp.sum(delta_matrix * overlap_term[:, :, None], axis=1) / 2
|
75
133
|
)
|
76
134
|
return unit_positions
|
77
|
-
|
135
|
+
|
78
136
|
@partial(jax.jit, static_argnums=(0,)) # replace the _world_step method
|
79
137
|
def _world_step( # modified version of JaxMARL's SMAX _world_step
|
80
138
|
self,
|
@@ -82,15 +140,25 @@ class Parabellum(SMAX):
|
|
82
140
|
state: State,
|
83
141
|
actions: Tuple[chex.Array, chex.Array],
|
84
142
|
) -> Tuple[Dict[str, chex.Array], State, Dict[str, float], Dict[str, bool], Dict]:
|
85
|
-
|
86
143
|
@partial(jax.vmap, in_axes=(None, None, 0, 0))
|
87
|
-
def
|
144
|
+
def intersect_fn(pos, new_pos, obs, obs_end):
|
88
145
|
d1 = jnp.cross(obs - pos, new_pos - pos)
|
89
146
|
d2 = jnp.cross(obs_end - pos, new_pos - pos)
|
90
147
|
d3 = jnp.cross(pos - obs, obs_end - obs)
|
91
148
|
d4 = jnp.cross(new_pos - obs, obs_end - obs)
|
92
149
|
return (d1 * d2 <= 0) & (d3 * d4 <= 0)
|
93
150
|
|
151
|
+
def raster_crossing(pos, new_pos):
|
152
|
+
pos, new_pos = pos.astype(jnp.int32), new_pos.astype(jnp.int32)
|
153
|
+
raster = self.terrain_raster
|
154
|
+
axis = jnp.argmax(jnp.abs(new_pos - pos), axis=-1)
|
155
|
+
minimum = jnp.minimum(pos[axis], new_pos[axis]).squeeze()
|
156
|
+
maximum = jnp.maximum(pos[axis], new_pos[axis]).squeeze()
|
157
|
+
segment = jnp.where(axis == 0, raster[pos[1]], raster.T[pos[0]])
|
158
|
+
segment = jnp.where(jnp.arange(segment.shape[0]) >= minimum, segment, 0)
|
159
|
+
segment = jnp.where(jnp.arange(segment.shape[0]) <= maximum, segment, 0)
|
160
|
+
return jnp.any(segment)
|
161
|
+
|
94
162
|
def update_position(idx, vec):
|
95
163
|
# Compute the movements slightly strangely.
|
96
164
|
# The velocities below are for diagonal directions
|
@@ -114,8 +182,10 @@ class Parabellum(SMAX):
|
|
114
182
|
############################################ avoid going into obstacles
|
115
183
|
obs = self.obstacle_coords
|
116
184
|
obs_end = obs + self.obstacle_deltas
|
117
|
-
inters = jnp.any(
|
118
|
-
|
185
|
+
inters = jnp.any(intersect_fn(pos, new_pos, obs, obs_end))
|
186
|
+
rastersects = raster_crossing(pos, new_pos)
|
187
|
+
flag = jnp.logical_or(inters, rastersects)
|
188
|
+
new_pos = jnp.where(flag, pos, new_pos)
|
119
189
|
|
120
190
|
#######################################################################
|
121
191
|
#######################################################################
|
@@ -224,19 +294,41 @@ class Parabellum(SMAX):
|
|
224
294
|
perform_agent_action
|
225
295
|
)(jnp.arange(self.num_agents), actions, keys)
|
226
296
|
|
227
|
-
#
|
297
|
+
# units push each other
|
228
298
|
new_pos = self._our_push_units_away(pos, state.unit_types)
|
229
299
|
|
230
|
-
# avoid going into obstacles after being pushed
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
300
|
+
# avoid going into obstacles after being pushed
|
301
|
+
|
302
|
+
bondaries_coords = jnp.array(
|
303
|
+
[[0, 0], [0, 0], [self.map_width, 0], [0, self.map_height]]
|
304
|
+
)
|
305
|
+
bondaries_deltas = jnp.array(
|
306
|
+
[
|
307
|
+
[self.map_width, 0],
|
308
|
+
[0, self.map_height],
|
309
|
+
[0, self.map_height],
|
310
|
+
[self.map_width, 0],
|
311
|
+
]
|
312
|
+
)
|
313
|
+
obstacle_coords = jnp.concatenate(
|
314
|
+
[self.obstacle_coords, bondaries_coords]
|
315
|
+
) # add the map boundaries to the obstacles to avoid
|
316
|
+
obstacle_deltas = jnp.concatenate(
|
317
|
+
[self.obstacle_deltas, bondaries_deltas]
|
318
|
+
) # add the map boundaries to the obstacles to avoid
|
319
|
+
obst_start = obstacle_coords
|
320
|
+
obst_end = obst_start + obstacle_deltas
|
321
|
+
|
322
|
+
def check_obstacles(pos, new_pos, obst_start, obst_end):
|
323
|
+
intersects = jnp.any(intersect_fn(pos, new_pos, obst_start, obst_end))
|
324
|
+
rastersect = raster_crossing(pos, new_pos)
|
325
|
+
flag = jnp.logical_or(intersects, rastersect)
|
326
|
+
return jnp.where(flag, pos, new_pos)
|
327
|
+
|
328
|
+
pos = jax.vmap(check_obstacles, in_axes=(0, 0, None, None))(
|
329
|
+
pos, new_pos, obst_start, obst_end
|
330
|
+
)
|
331
|
+
|
240
332
|
# Multiple enemies can attack the same unit.
|
241
333
|
# We have `(health_diff, attacked_idx)` pairs.
|
242
334
|
# `jax.lax.scatter_add` aggregates these exactly
|
@@ -278,19 +370,23 @@ class Parabellum(SMAX):
|
|
278
370
|
)
|
279
371
|
return state
|
280
372
|
|
373
|
+
|
281
374
|
if __name__ == "__main__":
|
282
|
-
|
283
|
-
|
284
|
-
|
375
|
+
n_envs = 4
|
376
|
+
kwargs = dict(map_width=64, map_height=64)
|
377
|
+
env = Parabellum(scenarios["default"], **kwargs)
|
378
|
+
rng, reset_rng = random.split(random.PRNGKey(0))
|
379
|
+
reset_key = random.split(reset_rng, n_envs)
|
380
|
+
obs, state = vmap(env.reset)(reset_key)
|
285
381
|
state_seq = []
|
286
|
-
for step in range(100):
|
287
|
-
rng, key = random.split(rng)
|
288
|
-
key_act = random.split(key, len(env.agents))
|
289
|
-
actions = {
|
290
|
-
agent: jax.random.randint(key_act[i], (), 0, 5)
|
291
|
-
for i, agent in enumerate(env.agents)
|
292
|
-
}
|
293
|
-
_, state, _, _, _ = env.step(key, state, actions)
|
294
|
-
state_seq.append((obs, state, actions))
|
295
|
-
|
296
382
|
|
383
|
+
for i in range(10):
|
384
|
+
rng, act_rng, step_rng = random.split(rng, 3)
|
385
|
+
act_key = random.split(act_rng, (len(env.agents), n_envs))
|
386
|
+
act = {
|
387
|
+
a: vmap(env.action_space(a).sample)(act_key[i])
|
388
|
+
for i, a in enumerate(env.agents)
|
389
|
+
}
|
390
|
+
step_key = random.split(step_rng, n_envs)
|
391
|
+
state_seq.append((step_key, state, act))
|
392
|
+
obs, state, reward, done, infos = vmap(env.step)(step_key, state, act)
|
@@ -0,0 +1,16 @@
|
|
1
|
+
# map.py
|
2
|
+
# parabellum map functions
|
3
|
+
# by: Noah Syrkis
|
4
|
+
|
5
|
+
# imports
|
6
|
+
import jax.numpy as jnp
|
7
|
+
import jax
|
8
|
+
|
9
|
+
|
10
|
+
# functions
|
11
|
+
def _span(start, delta):
    """Return the half-open ``[lo, hi)`` index range covered by ``start`` and ``start + delta``.

    The range is at least one cell wide (so zero deltas still mark a cell) and
    is clamped at 0 so negative coordinates do not wrap around via Python's
    negative slicing.
    """
    lo, hi = min(start, start + delta), max(start, start + delta)
    hi = max(hi, lo + 1)  # a zero delta still covers the start cell
    return max(lo, 0), max(hi, 0)


def map_fn(width, height, obst_coord, obst_delta):
    """Create a map from the given width, height, and obstacle coordinates and deltas.

    Each obstacle with start ``(x, y)`` and delta ``(dx, dy)`` marks the cells
    between ``(x, y)`` and ``(x + dx, y + dy)`` with ``1``:

    * along an axis with a positive delta the span is ``[x, x + dx)``
      (unchanged from earlier behaviour);
    * negative deltas are normalised, so the span is ``[x + dx, x)``;
    * along an axis with a zero delta the obstacle is one cell wide, so
      line-segment obstacles such as ``(10, 0)`` are no longer dropped;
    * cells with negative indices are clipped instead of wrapping around.

    Args:
        width: size of the first axis (indexed by ``x``).
        height: size of the second axis (indexed by ``y``).
        obst_coord: iterable of ``(x, y)`` integer start points; Python ints
            or concrete JAX/NumPy scalars.
        obst_delta: iterable of ``(dx, dy)`` integer extents, paired
            element-wise with ``obst_coord``.

    Returns:
        A float array of shape ``(width, height)`` with ``1`` on obstacle
        cells and ``0`` elsewhere.
    """
    # NOTE(review): the result is indexed [x, y] with shape (width, height),
    # while Parabellum reads `terrain_raster.shape` as (height, width);
    # transpose before using it as a terrain raster on non-square maps.
    m = jnp.zeros((width, height))
    for (x, y), (dx, dy) in zip(obst_coord, obst_delta):
        # int() accepts plain ints as well as concrete jnp/np scalars
        x0, x1 = _span(int(x), int(dx))
        y0, y1 = _span(int(y), int(dy))
        m = m.at[x0:x1, y0:y1].set(1)
    return m
|
@@ -0,0 +1,127 @@
|
|
1
|
+
# run.py
|
2
|
+
# parabellum run game live
|
3
|
+
# by: Noah Syrkis
|
4
|
+
|
5
|
+
# Noah Syrkis
|
6
|
+
import pygame
|
7
|
+
from jax import random
|
8
|
+
from functools import partial
|
9
|
+
import darkdetect
|
10
|
+
import jax.numpy as jnp
|
11
|
+
from chex import dataclass
|
12
|
+
import jaxmarl
|
13
|
+
from typing import Tuple, List, Dict, Optional
|
14
|
+
import parabellum as pb
|
15
|
+
|
16
|
+
|
17
|
+
# constants: foreground / background colours follow the OS light/dark theme.
# The theme is queried once so fg and bg can never disagree (e.g. if the
# theme changes between two separate queries).
_DARK = darkdetect.isDark()
fg = (255, 255, 255) if _DARK else (0, 0, 0)
bg = (0, 0, 0) if _DARK else (255, 255, 255)
|
20
|
+
|
21
|
+
|
22
|
+
# types
State = jaxmarl.environments.smax.smax_env.State  # SMAX simulation state
Action = Dict[str, jnp.ndarray]  # keyed by agent name (see step_fn)
Obs = Reward = Done = Action  # aliases of the same Dict[str, jnp.ndarray] type
# history entries are (PRNG key, state before the step, action taken)
StateSeq = List[Tuple[jnp.ndarray, State, Action]]
|
26
|
+
|
27
|
+
|
28
|
+
@dataclass
class Control:
    """Mutable UI flags updated from pygame events by ``handle_event``."""

    # main loop keeps running while True; set to False on pygame.QUIT
    running: bool = True
    # toggled by any key press; while True the loop neither steps nor renders
    paused: bool = False
    # cursor position recorded on mouse-button-down, cleared on button-up
    click: Optional[Tuple[int, int]] = None
|
33
|
+
|
34
|
+
|
35
|
+
@dataclass
class Game:
    """Everything the live game loop carries from one iteration to the next."""

    clock: pygame.time.Clock  # used to cap the frame rate when rendering
    state: State  # current environment state
    obs: Dict  # observations returned by the latest reset/step
    state_seq: StateSeq  # history of (key, state, action) entries
    control: Control  # UI flags (running / paused / click)
    env: pb.Parabellum  # environment being simulated
    rng: jnp.ndarray  # PRNG key carried between steps (from random.PRNGKey/split)
|
44
|
+
|
45
|
+
|
46
|
+
def handle_event(event, control_state):
    """Fold a single pygame event into ``control_state`` and return it.

    - ``QUIT`` stops the main loop (``running = False``).
    - Mouse button down stores the current cursor position in ``click``;
      mouse button up clears it again.
    - Any key press toggles ``paused``.
    """
    kind = event.type
    if kind == pygame.QUIT:
        control_state.running = False
    elif kind == pygame.MOUSEBUTTONDOWN:
        control_state.click = pygame.mouse.get_pos()
    elif kind == pygame.MOUSEBUTTONUP:
        control_state.click = None
    elif kind == pygame.KEYDOWN:  # any key press toggles pause
        control_state.paused = not control_state.paused
    return control_state
|
58
|
+
|
59
|
+
|
60
|
+
def control_fn(game):
    """Drain the pygame event queue, applying every event to ``game.control``."""
    control = game.control
    for pending in pygame.event.get():
        control = handle_event(pending, control)
    game.control = control
    return game
|
65
|
+
|
66
|
+
|
67
|
+
def render_fn(screen, game, size=800):
    """Render the game.

    Nothing is drawn until ``game.state_seq`` holds at least three entries.
    After that, the last two entries are expanded with
    ``game.env.expand_state_seq`` and up to eight of the resulting frames are
    drawn in turn. Each frame draws:

    * the background fill;
    * the held mouse click (if any) as a red circle;
    * every unit as a small ``fg`` circle.

    The display is flipped and the clock ticked at 24 FPS after each frame.

    Args:
        screen: pygame display surface to draw on.
        game: current ``Game``; its ``env``, ``state_seq``, ``control`` and
            ``clock`` are used.
        size: pixel extent the map is scaled to along each axis (default 800,
            as before).

    Returns:
        ``game``, unchanged except that its clock has ticked.
    """
    if len(game.state_seq) < 3:
        return game
    # use the game's own env instead of a module-level global that only
    # exists when this file is run as a script
    env = game.env
    # scale x by the map width and y by the map height so non-square maps
    # are not skewed (identical to the old behaviour on square maps)
    scale = jnp.array([size / env.map_width, size / env.map_height])
    for _, state, _ in env.expand_state_seq(game.state_seq[-2:])[-8:]:
        screen.fill(bg)
        if game.control.click is not None:
            pygame.draw.circle(screen, "red", game.control.click, 10)
        for pos in state.unit_positions:
            pygame.draw.circle(screen, fg, (pos * scale).tolist(), 5)
        pygame.display.flip()
        game.clock.tick(24)  # limits FPS to 24
    return game
|
82
|
+
|
83
|
+
|
84
|
+
def step_fn(game):
    """Step in parabellum with one uniformly random action per agent.

    Steps:

    1. Split ``game.rng`` into the next carry key, an action key and a step key.
    2. Sample one action per agent from its action space.
    3. Append ``(step_key, state_before, action)`` to ``game.state_seq``.
    4. Step ``game.env`` and store the new state, observations and carry key
       back on ``game``.

    Returns:
        The same ``game`` object, mutated in place.
    """
    # use the game's own env instead of a module-level global that only
    # exists when this file is run as a script
    env = game.env
    rng, act_rng, step_key = random.split(game.rng, 3)
    act_keys = random.split(act_rng, env.num_agents)
    action = {
        agent: env.action_space(agent).sample(act_keys[i])
        for i, agent in enumerate(env.agents)
    }
    # record the state *before* stepping together with the key/action used
    game.state_seq.append((step_key, game.state, action))
    obs, state, _reward, _done, _info = env.step(step_key, game.state, action)
    game.state = state
    game.obs = obs
    game.rng = rng
    return game
|
99
|
+
|
100
|
+
|
101
|
+
# entry point: run a live game with random actions in a pygame window
if __name__ == "__main__":
    env = pb.Parabellum(pb.scenarios["default"])
    pygame.init()
    try:
        screen = pygame.display.set_mode((1000, 1000))
        render = partial(render_fn, screen)
        rng, key = random.split(random.PRNGKey(0))
        obs, state = env.reset(key)
        game = Game(
            control=Control(),
            env=env,
            rng=rng,
            state_seq=[],  # [(key, state, action)]
            clock=pygame.time.Clock(),
            state=state,
            obs=obs,
        )

        while game.control.running:
            game = control_fn(game)
            if game.control.paused:
                # tick the clock while paused so the loop does not spin at
                # 100% CPU (rendering, which normally ticks, is skipped)
                game.clock.tick(24)
                continue
            game = step_fn(game)
            game = render(game)
    finally:
        # always release the display, even if stepping/rendering raised
        pygame.quit()
|
126
|
+
|
127
|
+
|