parabellum 0.2.13__tar.gz → 0.2.15__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,104 @@
1
+ Metadata-Version: 2.1
2
+ Name: parabellum
3
+ Version: 0.2.15
4
+ Summary: Parabellum environment for parallel warfare simulation
5
+ Home-page: https://github.com/syrkis/parabellum
6
+ License: MIT
7
+ Keywords: warfare,simulation,parallel,environment
8
+ Author: Noah Syrkis
9
+ Author-email: desk@syrkis.com
10
+ Requires-Python: >=3.11,<4.0
11
+ Classifier: License :: OSI Approved :: MIT License
12
+ Classifier: Programming Language :: Python :: 3
13
+ Classifier: Programming Language :: Python :: 3.11
14
+ Classifier: Programming Language :: Python :: 3.12
15
+ Requires-Dist: darkdetect (>=0.8.0,<0.9.0)
16
+ Requires-Dist: jax (==0.4.17)
17
+ Requires-Dist: jaxmarl (==0.0.3)
18
+ Requires-Dist: jupyterlab (>=4.2.2,<5.0.0)
19
+ Requires-Dist: moviepy (>=1.0.3,<2.0.0)
20
+ Requires-Dist: poetry (>=1.8.3,<2.0.0)
21
+ Requires-Dist: pygame (>=2.5.2,<3.0.0)
22
+ Requires-Dist: tqdm (>=4.66.4,<5.0.0)
23
+ Project-URL: Repository, https://github.com/syrkis/parabellum
24
+ Description-Content-Type: text/markdown
25
+
26
+ # parabellum
27
+
28
+ Parabellum is an ultra-scalable, high-performance warfare simulation engine.
29
+ It is based on JaxMARL's SMAX environment, but has been heavily modified to
30
+ support a wide range of new features and improvements.
31
+
32
+ ## Installation
33
+
34
+ Parabellum is written in Python 3.11 and can be installed using pip:
35
+
36
+ ```bash
37
+ pip install parabellum
38
+ ```
39
+
40
+ ## Usage
41
+
42
+ Parabellum is designed to be used in conjunction with JAX, a high-performance
43
+ numerical computing library. Here is a simple example of how to use Parabellum
44
+ to simulate a game with 10 agents and 10 enemies, each taking random actions:
45
+
46
+ ```python
47
+ import parabellum as pb
48
+ from jax import random
49
+
50
+ # define the scenario
51
+ kwargs = dict(obstacle_coords=[(7, 7)], obstacle_deltas=[(10, 0)])
52
+ scenario = pb.Scenario(**kwargs) # <- Scenario is an important part of parabellum
53
+
54
+ # create the environment
55
+ kwargs = dict(map_width=256, map_height=256, num_agents=10, num_enemies=10)
56
+ env = pb.Parabellum(**kwargs) # <- Parabellum is the central class of parabellum
57
+
58
+ # initiate stochasticity
59
+ rng = random.PRNGKey(0)
60
+ rng, key = random.split(rng)
61
+
62
+ # initialize the environment state
63
+ obs, state = env.reset(key)
64
+ state_sequence = []
65
+
66
+ for _ in range(1000):
67
+
68
+ # manage stochasticity
69
+ rng, rng_act, key_step = random.split(rng, 3)
70
+ key_act = random.split(rng_act, len(env.agents))
71
+
72
+ # sample actions and append to state sequence
73
+ act = {a: env.action_space(a).sample(k)
74
+ for a, k in zip(env.agents, key_act)}
75
+
76
+ # step the environment
77
+ state_sequence.append((key_act, state, act))
78
+ obs, state, reward, done, info = env.step(key_step, act, state)
79
+
80
+
81
+ # save visualization of the state sequence
82
+ vis = pb.Visualizer(env, state_sequence) # <- Visualizer is a nice to have class
83
+ vis.animate()
84
+ ```
85
+
86
+
87
+ ## Features
88
+
89
+ - Obstacles — can be inserted into the map through the scenario (obstacle segments and a terrain raster).
90
+
91
+ ## TODO
92
+
93
+ - [x] Parallel pygame vis
94
+ - [ ] Parallel bullet renderings
95
+ - [ ] Combine parallel plots into one (maybe out of parabellum scope)
96
+ - [ ] Color for health?
97
+ - [ ] Add the ability to see ongoing game.
98
+ - [ ] Bug test friendly fire.
99
+ - [x] Start sim from arbitrary state.
100
+ - [ ] Save when the episode ends in some state/obs variable
101
+ - [ ] Look for the source of the bug when using more Allies than Enemies
102
+ - [ ] Inverted Y axis for parabellum visualization
103
+ - [ ] Units see through obstacles?
104
+
@@ -0,0 +1,78 @@
1
+ # parabellum
2
+
3
+ Parabellum is an ultra-scalable, high-performance warfare simulation engine.
4
+ It is based on JaxMARL's SMAX environment, but has been heavily modified to
5
+ support a wide range of new features and improvements.
6
+
7
+ ## Installation
8
+
9
+ Parabellum is written in Python 3.11 and can be installed using pip:
10
+
11
+ ```bash
12
+ pip install parabellum
13
+ ```
14
+
15
+ ## Usage
16
+
17
+ Parabellum is designed to be used in conjunction with JAX, a high-performance
18
+ numerical computing library. Here is a simple example of how to use Parabellum
19
+ to simulate a game with 10 agents and 10 enemies, each taking random actions:
20
+
21
+ ```python
22
+ import parabellum as pb
23
+ from jax import random
24
+
25
+ # define the scenario
26
+ kwargs = dict(obstacle_coords=[(7, 7)], obstacle_deltas=[(10, 0)])
27
+ scenario = pb.Scenario(**kwargs) # <- Scenario is an important part of parabellum
28
+
29
+ # create the environment
30
+ kwargs = dict(map_width=256, map_height=256, num_agents=10, num_enemies=10)
31
+ env = pb.Parabellum(**kwargs) # <- Parabellum is the central class of parabellum
32
+
33
+ # initiate stochasticity
34
+ rng = random.PRNGKey(0)
35
+ rng, key = random.split(rng)
36
+
37
+ # initialize the environment state
38
+ obs, state = env.reset(key)
39
+ state_sequence = []
40
+
41
+ for _ in range(1000):
42
+
43
+ # manage stochasticity
44
+ rng, rng_act, key_step = random.split(rng, 3)
45
+ key_act = random.split(rng_act, len(env.agents))
46
+
47
+ # sample actions and append to state sequence
48
+ act = {a: env.action_space(a).sample(k)
49
+ for a, k in zip(env.agents, key_act)}
50
+
51
+ # step the environment
52
+ state_sequence.append((key_act, state, act))
53
+ obs, state, reward, done, info = env.step(key_step, act, state)
54
+
55
+
56
+ # save visualization of the state sequence
57
+ vis = pb.Visualizer(env, state_sequence) # <- Visualizer is a nice to have class
58
+ vis.animate()
59
+ ```
60
+
61
+
62
+ ## Features
63
+
64
+ - Obstacles — can be inserted into the map through the scenario (obstacle segments and a terrain raster).
65
+
66
+ ## TODO
67
+
68
+ - [x] Parallel pygame vis
69
+ - [ ] Parallel bullet renderings
70
+ - [ ] Combine parallel plots into one (maybe out of parabellum scope)
71
+ - [ ] Color for health?
72
+ - [ ] Add the ability to see ongoing game.
73
+ - [ ] Bug test friendly fire.
74
+ - [x] Start sim from arbitrary state.
75
+ - [ ] Save when the episode ends in some state/obs variable
76
+ - [ ] Look for the source of the bug when using more Allies than Enemies
77
+ - [ ] Inverted Y axis for parabellum visualization
78
+ - [ ] Units see through obstacles?
@@ -0,0 +1,4 @@
1
+ from .env import Parabellum, Scenario, scenarios
2
+ from .vis import Visualizer
3
+
4
+ __all__ = ["Parabellum", "Visualizer", "Scenario", "scenarios"]
@@ -7,6 +7,7 @@ from jax import random
7
7
  from jax import jit
8
8
  from flax.struct import dataclass
9
9
  import chex
10
+ from jax import vmap
10
11
  from jaxmarl.environments.smax.smax_env import State, SMAX
11
12
  from typing import Tuple, Dict
12
13
  from functools import partial
@@ -16,7 +17,9 @@ from functools import partial
16
17
  class Scenario:
17
18
  """Parabellum scenario"""
18
19
 
19
- obstacle_coords: chex.Array
20
+ terrain_raster: chex.Array
21
+
22
+ obstacle_coords: chex.Array # TODO: use map instead of obstacles
20
23
  obstacle_deltas: chex.Array
21
24
 
22
25
  unit_types: chex.Array
@@ -30,8 +33,9 @@ class Scenario:
30
33
  # default scenario
31
34
  scenarios = {
32
35
  "default": Scenario(
33
- jnp.array([[6, 10], [26, 10]]) * 8,
34
- jnp.array([[0, 12], [0, 1]]) * 8,
36
+ jnp.eye(128, dtype=jnp.uint8),
37
+ jnp.array([[80, 0], [16, 12]]),
38
+ jnp.array([[0, 80], [0, 20]]),
35
39
  jnp.zeros((19,), dtype=jnp.uint8),
36
40
  9,
37
41
  10,
@@ -40,24 +44,78 @@ scenarios = {
40
44
 
41
45
 
42
46
  class Parabellum(SMAX):
43
- def __init__(
44
- self,
45
- scenario: Scenario = scenarios["default"],
46
- unit_type_attack_blasts=jnp.array([0, 0, 0, 0, 0, 0]) + 8,
47
- **kwargs,
48
- ):
49
- super().__init__(scenario=scenario, **kwargs)
50
- self.unit_type_attack_blasts = unit_type_attack_blasts
51
- self.obstacle_coords = scenario.obstacle_coords.astype(jnp.float32)
52
- self.obstacle_deltas = scenario.obstacle_deltas.astype(jnp.float32)
47
def __init__(self, scenario: Scenario, **kwargs):
    """Build a Parabellum environment whose map is sized by the scenario's terrain.

    Args:
        scenario: Scenario providing ``terrain_raster`` (read as
            ``(map_height, map_width)``), ``obstacle_coords`` and
            ``obstacle_deltas``. It is also forwarded to ``SMAX.__init__``.
        **kwargs: Extra keyword arguments forwarded to ``SMAX.__init__``.
            ``map_width`` / ``map_height`` may be passed only if they match the
            raster shape, because the map size is derived from the raster.

    Raises:
        ValueError: if an explicit ``map_width`` or ``map_height`` disagrees
            with ``scenario.terrain_raster.shape``.
    """
    map_height, map_width = scenario.terrain_raster.shape
    # Forwarding caller-supplied map_width/map_height next to the derived values
    # used to fail with "got multiple values for keyword argument"; accept them
    # when they agree with the raster and reject them loudly otherwise.
    for name, size in (("map_height", map_height), ("map_width", map_width)):
        requested = kwargs.pop(name, size)
        if requested != size:
            raise ValueError(
                f"{name}={requested} does not match the scenario terrain raster "
                f"(shape {scenario.terrain_raster.shape}); the map size is taken "
                f"from the raster"
            )
    super(Parabellum, self).__init__(
        scenario=scenario, map_height=map_height, map_width=map_width, **kwargs
    )
    self.terrain_raster = scenario.terrain_raster
    self.obstacle_coords = scenario.obstacle_coords
    self.obstacle_deltas = scenario.obstacle_deltas
    self.unit_type_attack_blasts = jnp.zeros((19,), dtype=jnp.float32)
    self.max_steps = 200
    # SMAX's push-apart pass is replaced by an identity: units are pushed apart
    # inside _world_step (via _our_push_units_away) so obstacles are respected.
    self._push_units_away = lambda x: x
57
+
58
@partial(jax.jit, static_argnums=(0,))
def reset(self, key: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]:
    """Environment-specific reset.

    Team 0 (the ``num_allies`` allies, indices ``[0, num_allies)``) is placed
    around ``(map_width / 4, map_height / 2)``. Team 1 (the ``num_enemies``
    enemies) is placed around ``(3 * map_width / 4, map_height / 2)``. Every
    unit is jittered by uniform noise in ``[-2, 2)`` per axis. When
    ``self.smacv2_position_generation`` / ``self.smacv2_unit_type_generation``
    are set, the generated positions / unit types are used instead.

    Args:
        key: PRNG key for this reset.

    Returns:
        ``(obs, state)``: ``obs`` is the dict from ``get_obs`` with an extra
        ``"world_state"`` entry, ``state`` is the fresh SMAX ``State``.
    """
    key, team_0_key, team_1_key = jax.random.split(key, num=3)
    # team 0 (allies) starts in the left quarter of the map, vertically centred
    team_0_start = jnp.stack(
        [jnp.array([self.map_width / 4, self.map_height / 2])] * self.num_allies
    )
    team_0_start_noise = jax.random.uniform(
        team_0_key, shape=(self.num_allies, 2), minval=-2, maxval=2
    )
    team_0_start = team_0_start + team_0_start_noise
    # team 1 (enemies) starts in the right quarter of the map, vertically centred
    team_1_start = jnp.stack(
        [jnp.array([self.map_width / 4 * 3, self.map_height / 2])]
        * self.num_enemies
    )
    team_1_start_noise = jax.random.uniform(
        team_1_key, shape=(self.num_enemies, 2), minval=-2, maxval=2
    )
    team_1_start = team_1_start + team_1_start_noise
    unit_positions = jnp.concatenate([team_0_start, team_1_start])
    key, pos_key = jax.random.split(key)
    generated_unit_positions = self.position_generator.generate(pos_key)
    # lax.select(pred, on_true, on_false): generated positions win when enabled
    unit_positions = jax.lax.select(
        self.smacv2_position_generation, generated_unit_positions, unit_positions
    )
    unit_teams = jnp.zeros((self.num_agents,))
    unit_teams = unit_teams.at[self.num_allies :].set(1)
    unit_weapon_cooldowns = jnp.zeros((self.num_agents,))
    # default behaviour spawn all marines
    # NOTE(review): self.scenario is presumably the unit-type array stored by
    # SMAX.__init__ (not the Scenario object) — confirm against jaxmarl.
    unit_types = (
        jnp.zeros((self.num_agents,), dtype=jnp.uint8)
        if self.scenario is None
        else self.scenario
    )
    key, unit_type_key = jax.random.split(key)
    generated_unit_types = self.unit_type_generator.generate(unit_type_key)
    unit_types = jax.lax.select(
        self.smacv2_unit_type_generation, generated_unit_types, unit_types
    )
    unit_health = self.unit_type_health[unit_types]
    state = State(
        unit_positions=unit_positions,
        unit_alive=jnp.ones((self.num_agents,), dtype=jnp.bool_),
        unit_teams=unit_teams,
        unit_health=unit_health,
        unit_types=unit_types,
        prev_movement_actions=jnp.zeros((self.num_agents, 2)),
        prev_attack_actions=jnp.zeros((self.num_agents,), dtype=jnp.int32),
        time=0,
        terminal=False,
        unit_weapon_cooldowns=unit_weapon_cooldowns,
    )
    # an identity in this class (see __init__); kept to mirror the SMAX reset
    state = self._push_units_away(state)
    obs = self.get_obs(state)
    world_state = self.get_world_state(state)
    obs["world_state"] = jax.lax.stop_gradient(world_state)
    return obs, state
115
+
116
+ def _our_push_units_away(
117
+ self, pos, unit_types, firmness: float = 1.0
118
+ ): # copy of SMAX._push_units_away but used without state and called inside _world_step to allow more obstacles constraints
61
119
  delta_matrix = pos[:, None] - pos[None, :]
62
120
  dist_matrix = (
63
121
  jnp.linalg.norm(delta_matrix, axis=-1)
@@ -74,7 +132,7 @@ class Parabellum(SMAX):
74
132
  + firmness * jnp.sum(delta_matrix * overlap_term[:, :, None], axis=1) / 2
75
133
  )
76
134
  return unit_positions
77
-
135
+
78
136
  @partial(jax.jit, static_argnums=(0,)) # replace the _world_step method
79
137
  def _world_step( # modified version of JaxMARL's SMAX _world_step
80
138
  self,
@@ -82,15 +140,25 @@ class Parabellum(SMAX):
82
140
  state: State,
83
141
  actions: Tuple[chex.Array, chex.Array],
84
142
  ) -> Tuple[Dict[str, chex.Array], State, Dict[str, float], Dict[str, bool], Dict]:
85
-
86
143
  @partial(jax.vmap, in_axes=(None, None, 0, 0))
87
- def inter_fn(pos, new_pos, obs, obs_end):
144
+ def intersect_fn(pos, new_pos, obs, obs_end):
88
145
  d1 = jnp.cross(obs - pos, new_pos - pos)
89
146
  d2 = jnp.cross(obs_end - pos, new_pos - pos)
90
147
  d3 = jnp.cross(pos - obs, obs_end - obs)
91
148
  d4 = jnp.cross(new_pos - obs, obs_end - obs)
92
149
  return (d1 * d2 <= 0) & (d3 * d4 <= 0)
93
150
 
151
+ def raster_crossing(pos, new_pos):
152
+ pos, new_pos = pos.astype(jnp.int32), new_pos.astype(jnp.int32)
153
+ raster = self.terrain_raster
154
+ axis = jnp.argmax(jnp.abs(new_pos - pos), axis=-1)
155
+ minimum = jnp.minimum(pos[axis], new_pos[axis]).squeeze()
156
+ maximum = jnp.maximum(pos[axis], new_pos[axis]).squeeze()
157
+ segment = jnp.where(axis == 0, raster[pos[1]], raster.T[pos[0]])
158
+ segment = jnp.where(jnp.arange(segment.shape[0]) >= minimum, segment, 0)
159
+ segment = jnp.where(jnp.arange(segment.shape[0]) <= maximum, segment, 0)
160
+ return jnp.any(segment)
161
+
94
162
  def update_position(idx, vec):
95
163
  # Compute the movements slightly strangely.
96
164
  # The velocities below are for diagonal directions
@@ -114,8 +182,10 @@ class Parabellum(SMAX):
114
182
  ############################################ avoid going into obstacles
115
183
  obs = self.obstacle_coords
116
184
  obs_end = obs + self.obstacle_deltas
117
- inters = jnp.any(inter_fn(pos, new_pos, obs, obs_end))
118
- new_pos = jnp.where(inters, pos, new_pos)
185
+ inters = jnp.any(intersect_fn(pos, new_pos, obs, obs_end))
186
+ rastersects = raster_crossing(pos, new_pos)
187
+ flag = jnp.logical_or(inters, rastersects)
188
+ new_pos = jnp.where(flag, pos, new_pos)
119
189
 
120
190
  #######################################################################
121
191
  #######################################################################
@@ -224,19 +294,41 @@ class Parabellum(SMAX):
224
294
  perform_agent_action
225
295
  )(jnp.arange(self.num_agents), actions, keys)
226
296
 
227
- # checked that no unit passed through an obstacles
297
+ # units push each other
228
298
  new_pos = self._our_push_units_away(pos, state.unit_types)
229
299
 
230
- # avoid going into obstacles after being pushed
231
- obs = self.obstacle_coords
232
- obs_end = obs + self.obstacle_deltas
233
-
234
- def check_obstacles(pos, new_pos, obs, obs_end):
235
- inters = jnp.any(inter_fn(pos, new_pos, obs, obs_end))
236
- return jnp.where(inters, pos, new_pos)
237
-
238
- pos = jax.vmap(check_obstacles, in_axes=(0,0,None,None))(pos, new_pos, obs, obs_end)
239
-
300
+ # avoid going into obstacles after being pushed
301
+
302
+ bondaries_coords = jnp.array(
303
+ [[0, 0], [0, 0], [self.map_width, 0], [0, self.map_height]]
304
+ )
305
+ bondaries_deltas = jnp.array(
306
+ [
307
+ [self.map_width, 0],
308
+ [0, self.map_height],
309
+ [0, self.map_height],
310
+ [self.map_width, 0],
311
+ ]
312
+ )
313
+ obstacle_coords = jnp.concatenate(
314
+ [self.obstacle_coords, bondaries_coords]
315
+ ) # add the map boundaries to the obstacles to avoid
316
+ obstacle_deltas = jnp.concatenate(
317
+ [self.obstacle_deltas, bondaries_deltas]
318
+ ) # add the map boundaries to the obstacles to avoid
319
+ obst_start = obstacle_coords
320
+ obst_end = obst_start + obstacle_deltas
321
+
322
+ def check_obstacles(pos, new_pos, obst_start, obst_end):
323
+ intersects = jnp.any(intersect_fn(pos, new_pos, obst_start, obst_end))
324
+ rastersect = raster_crossing(pos, new_pos)
325
+ flag = jnp.logical_or(intersects, rastersect)
326
+ return jnp.where(flag, pos, new_pos)
327
+
328
+ pos = jax.vmap(check_obstacles, in_axes=(0, 0, None, None))(
329
+ pos, new_pos, obst_start, obst_end
330
+ )
331
+
240
332
  # Multiple enemies can attack the same unit.
241
333
  # We have `(health_diff, attacked_idx)` pairs.
242
334
  # `jax.lax.scatter_add` aggregates these exactly
@@ -278,19 +370,23 @@ class Parabellum(SMAX):
278
370
  )
279
371
  return state
280
372
 
373
+
281
374
if __name__ == "__main__":
    # Smoke run: step n_envs copies of the default scenario in parallel with
    # random actions for a few steps.
    n_envs = 4
    # The map size is derived from the scenario's terrain raster; passing a
    # different map_width/map_height would conflict with it, so none is given.
    env = Parabellum(scenarios["default"])
    rng, reset_rng = random.split(random.PRNGKey(0))
    reset_key = random.split(reset_rng, n_envs)
    obs, state = vmap(env.reset)(reset_key)
    state_seq = []

    for _ in range(10):
        rng, act_rng, step_rng = random.split(rng, 3)
        act_key = random.split(act_rng, (len(env.agents), n_envs))
        act = {
            a: vmap(env.action_space(a).sample)(act_key[i])
            for i, a in enumerate(env.agents)
        }
        step_key = random.split(step_rng, n_envs)
        state_seq.append((step_key, state, act))
        obs, state, reward, done, infos = vmap(env.step)(step_key, state, act)
@@ -0,0 +1,16 @@
1
+ # map.py
2
+ # parabellum map functions
3
+ # by: Noah Syrkis
4
+
5
+ # imports
6
+ import jax.numpy as jnp
7
+ import jax
8
+
9
+
10
+ # functions
11
def _span(start, delta):
    """Return the half-open index range ``[lo, hi)`` covered by ``start`` .. ``start + delta``."""
    lo, hi = sorted((int(start), int(start) + int(delta)))
    hi = max(hi, lo + 1)  # a zero extent still marks one cell (segment obstacles)
    return max(lo, 0), max(hi, 0)  # clamp: negative slice bounds would wrap around


def map_fn(width, height, obst_coord, obst_delta):
    """Create a map from the given width, height, and obstacle coordinates and deltas.

    Each obstacle ``i`` starts at ``obst_coord[i] == (x, y)`` and extends by
    ``obst_delta[i] == (dx, dy)``. Cells whose first index lies in
    ``[x, x + dx)`` and whose second index lies in ``[y, y + dy)`` are set to 1.
    Negative deltas are measured from the other endpoint, a zero delta yields a
    one-cell-thick line (so segment obstacles such as ``(0, 80)`` are not
    dropped), and ranges are clipped at 0 instead of wrapping around.

    Args:
        width: Size of the first axis of the returned array.
        height: Size of the second axis of the returned array.
        obst_coord: Iterable of integer ``(x, y)`` start coordinates.
        obst_delta: Iterable of integer ``(dx, dy)`` extents, paired with
            ``obst_coord``.

    Returns:
        A float ``jnp`` array of shape ``(width, height)``: 1 on obstacle cells,
        0 elsewhere.

    NOTE(review): the result is indexed ``[x, y]``, whereas ``Parabellum`` reads
    ``terrain_raster.shape`` as ``(map_height, map_width)`` and indexes rows by
    ``y``; transpose before using it as a terrain raster.
    """
    m = jnp.zeros((width, height))
    for (x, y), (dx, dy) in zip(obst_coord, obst_delta):
        x0, x1 = _span(x, dx)
        y0, y1 = _span(y, dy)
        m = m.at[x0:x1, y0:y1].set(1)
    return m
@@ -0,0 +1,127 @@
1
+ # run.py
2
+ # parabellum run game live
3
+ # by: Noah Syrkis
4
+
5
+ # Noah Syrkis
6
+ import pygame
7
+ from jax import random
8
+ from functools import partial
9
+ import darkdetect
10
+ import jax.numpy as jnp
11
+ from chex import dataclass
12
+ import jaxmarl
13
+ from typing import Tuple, List, Dict, Optional
14
+ import parabellum as pb
15
+
16
+
17
# constants
# Query the OS theme once so the foreground and background always contrast
# (two separate isDark() calls could disagree if the theme flips in between).
_dark = darkdetect.isDark()
fg = (255, 255, 255) if _dark else (0, 0, 0)  # colour used to draw units
bg = (0, 0, 0) if _dark else (255, 255, 255)  # colour used to clear the screen


# types
State = jaxmarl.environments.smax.smax_env.State
Obs = Reward = Done = Action = Dict[str, jnp.ndarray]  # per-agent dicts
StateSeq = List[Tuple[jnp.ndarray, State, Action]]  # [(key, state, action)]
26
+
27
+
28
@dataclass
class Control:
    """Mutable UI flags updated from pygame events by handle_event."""

    running: bool = True  # set to False on pygame.QUIT; ends the main loop
    paused: bool = False  # toggled by any key press; step/render are skipped while True
    click: Optional[Tuple[int, int]] = None  # mouse position while a button is held, else None
33
+
34
+
35
@dataclass
class Game:
    """Everything the live loop threads through control_fn, step_fn and render_fn."""

    clock: pygame.time.Clock  # frame limiter; render_fn ticks it at 24 FPS
    state: State  # current environment state
    obs: Dict  # latest observations returned by reset/step
    state_seq: StateSeq  # history of (step_key, state, action) tuples
    control: Control  # UI flags (running / paused / click)
    env: pb.Parabellum  # environment being simulated
    rng: jnp.ndarray  # PRNG key; step_fn splits it and stores the remainder
44
+
45
+
46
def handle_event(event, control_state):
    """Fold a single pygame ``event`` into ``control_state`` and return it.

    QUIT stops the loop, a mouse-button press records the cursor position, a
    release clears it, and any key press toggles pause. Other events are
    ignored.
    """
    kind = event.type
    if kind == pygame.QUIT:
        control_state.running = False
    elif kind == pygame.MOUSEBUTTONDOWN:
        control_state.click = pygame.mouse.get_pos()
    elif kind == pygame.MOUSEBUTTONUP:
        control_state.click = None
    elif kind == pygame.KEYDOWN:  # any key press pauses / resumes
        control_state.paused = not control_state.paused
    return control_state
58
+
59
+
60
def control_fn(game):
    """Drain the pygame event queue, applying each event to ``game.control``."""
    pending = pygame.event.get()
    for evt in pending:
        game.control = handle_event(evt, game.control)
    return game
65
+
66
+
67
def render_fn(screen, game):
    """Draw the most recent environment step onto ``screen``.

    Does nothing until ``game.state_seq`` holds at least three entries.
    Otherwise the last two recorded entries are passed through
    ``env.expand_state_seq``, and the final eight resulting states are drawn
    one frame each (at most 24 FPS). Each frame clears the screen, marks the
    current mouse click (if any) in red, and draws every unit as a circle,
    with map coordinates scaled to an 800-pixel square.
    """
    if len(game.state_seq) < 3:
        return game
    # Use the game's own environment rather than a module-level `env` that
    # only exists when this file is run as a script.
    env = game.env
    # Scale each axis by its own extent so non-square maps are not distorted.
    map_size = jnp.array([env.map_width, env.map_height])
    for _, state, _ in env.expand_state_seq(game.state_seq[-2:])[-8:]:
        screen.fill(bg)
        if game.control.click is not None:
            pygame.draw.circle(screen, "red", game.control.click, 10)
        for pos in state.unit_positions:
            pos = (pos / map_size * 800).tolist()
            pygame.draw.circle(screen, fg, pos, 5)
        pygame.display.flip()
        game.clock.tick(24)  # limits FPS to 24
    return game
82
+
83
+
84
def step_fn(game):
    """Advance ``game`` by one environment step using randomly sampled actions.

    The pre-step ``(step_key, state, action)`` triple is appended to
    ``game.state_seq``. ``game.state``, ``game.obs`` and ``game.rng`` are then
    replaced with the post-step values.
    """
    # Use the game's own environment rather than a module-level `env` that
    # only exists when this file is run as a script.
    env = game.env
    rng, act_rng, step_key = random.split(game.rng, 3)
    act_key = random.split(act_rng, env.num_agents)
    action = {
        a: env.action_space(a).sample(act_key[i]) for i, a in enumerate(env.agents)
    }
    game.state_seq.append((step_key, game.state, action))
    obs, state, _reward, _done, _info = env.step(step_key, game.state, action)
    game.state = state
    game.obs = obs
    game.rng = rng
    return game
99
+
100
+
101
# entry point
if __name__ == "__main__":
    env = pb.Parabellum(pb.scenarios["default"])
    pygame.init()
    screen = pygame.display.set_mode((1000, 1000))
    render = partial(render_fn, screen)
    rng, key = random.split(random.PRNGKey(0))
    obs, state = env.reset(key)
    game = Game(
        control=Control(),
        env=env,
        rng=rng,
        state_seq=[],  # [(key, state, action)]
        clock=pygame.time.Clock(),
        state=state,
        obs=obs,
    )

    while game.control.running:
        game = control_fn(game)
        if game.control.paused:
            # Throttle the idle loop; without a tick, polling events while
            # paused spins a CPU core at 100%.
            game.clock.tick(24)
            continue
        game = step_fn(game)
        game = render(game)

    pygame.quit()
126
+
127
+