parabellum 0.0.0__py3-none-any.whl → 0.0.73__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,230 +0,0 @@
1
- """Visualizer for the Parabellum environment"""
2
-
3
- from tqdm import tqdm
4
- import jax.numpy as jnp
5
- import jax
6
- from jax import vmap
7
- from functools import partial
8
- import darkdetect
9
- import pygame
10
- from moviepy.editor import ImageSequenceClip
11
- from typing import Optional
12
- from jaxmarl.environments.multi_agent_env import MultiAgentEnv
13
- from jaxmarl.viz.visualizer import SMAXVisualizer
14
-
15
- # default dict
16
- from collections import defaultdict
17
-
18
-
19
- # constants
20
- action_to_symbol = {0: "↑", 1: "→", 2: "↓", 3: "←", 4: "Ø"}
21
-
22
-
23
- class Visualizer(SMAXVisualizer):
24
- def __init__(self, env: MultiAgentEnv, state_seq, reward_seq=None):
25
- super().__init__(env, state_seq, reward_seq)
26
- # remove fig and ax from super
27
- self.fig, self.ax = None, None
28
- self.bg = (0, 0, 0) if darkdetect.isDark() else (255, 255, 255)
29
- self.fg = (235, 235, 235) if darkdetect.isDark() else (20, 20, 20)
30
- self.s = 1000
31
- self.scale = self.s / self.env.map_width
32
- self.action_seq = [action for _, _, action in state_seq] # bcs SMAX bug
33
- self.bullet_seq = bullet_fn(self.env, self.state_seq)
34
-
35
- def render_agents(self, screen, state):
36
- time_tuple = zip(
37
- state.unit_positions,
38
- state.unit_teams,
39
- state.unit_types,
40
- state.unit_health,
41
- )
42
- for idx, (pos, team, kind, hp) in enumerate(time_tuple):
43
- face_col = self.fg if int(team.item()) == 0 else self.bg
44
- pos = tuple((pos * self.scale).tolist())
45
-
46
- # draw the agent
47
- if hp > 0:
48
- hp_frac = hp / self.env.unit_type_health[kind]
49
- unit_size = self.env.unit_type_radiuses[kind]
50
- radius = jnp.ceil((unit_size * self.scale * hp_frac)).astype(int) + 1
51
- pygame.draw.circle(screen, face_col, pos, radius)
52
- pygame.draw.circle(screen, self.fg, pos, radius, 1)
53
-
54
- # draw the sight range
55
- # sight_range = self.env.unit_type_sight_ranges[kind] * self.scale
56
- # pygame.draw.circle(screen, self.fg, pos, sight_range.astype(int), 2)
57
-
58
- # draw attack range
59
- # attack_range = self.env.unit_type_attack_ranges[kind] * self.scale
60
- # pygame.draw.circle(screen, self.fg, pos, attack_range.astype(int), 2)
61
- # work out which agents are being shot
62
-
63
- def render_action(self, screen, action):
64
- def coord_fn(idx, n, team):
65
- return (
66
- self.s / 20 if team == 0 else self.s - self.s / 20,
67
- # vertically centered so that n / 2 is above and below the center
68
- self.s / 2 - (n / 2) * self.s / 20 + idx * self.s / 20,
69
- )
70
-
71
- for idx in range(self.env.num_allies):
72
- symb = action_to_symbol.get(action[f"ally_{idx}"].astype(int).item(), "Ø")
73
- font = pygame.font.SysFont("Fira Code", jnp.sqrt(self.s).astype(int).item())
74
- text = font.render(symb, True, self.fg)
75
- coord = coord_fn(idx, self.env.num_allies, 0)
76
- screen.blit(text, coord)
77
-
78
- for idx in range(self.env.num_enemies):
79
- symb = action_to_symbol.get(action[f"enemy_{idx}"].astype(int).item(), "Ø")
80
- font = pygame.font.SysFont("Fira Code", jnp.sqrt(self.s).astype(int).item())
81
- text = font.render(symb, True, self.fg)
82
- coord = coord_fn(idx, self.env.num_enemies, 1)
83
- screen.blit(text, coord)
84
-
85
- def render_obstacles(self, screen):
86
- for c, d in zip(self.env.obstacle_coords, self.env.obstacle_deltas):
87
- d = tuple(((c + d) * self.scale).tolist())
88
- c = tuple((c * self.scale).tolist())
89
- pygame.draw.line(screen, self.fg, c, d, 5)
90
-
91
- def render_bullets(self, screen, bullets, jdx):
92
- jdx += 1
93
- ally_bullets, enemy_bullets = bullets
94
- for source, target in ally_bullets:
95
- position = source + (target - source) * jdx / 8
96
- position *= self.scale
97
- pygame.draw.circle(screen, self.fg, tuple(position.tolist()), 3)
98
- for source, target in enemy_bullets:
99
- position = source + (target - source) * jdx / 8
100
- position *= self.scale
101
- pygame.draw.circle(screen, self.fg, tuple(position.tolist()), 3)
102
-
103
- def animate(self, save_fname: str = "parabellum.mp4"):
104
- if not self.have_expanded:
105
- self.expand_state_seq()
106
- frames = [] # frames for the video
107
- pygame.init() # initialize pygame
108
- for idx, (_, state, _) in tqdm(
109
- enumerate(self.state_seq), total=len(self.state_seq)
110
- ):
111
- screen = pygame.Surface(
112
- (self.s, self.s), pygame.HWSURFACE | pygame.DOUBLEBUF
113
- )
114
- screen.fill(self.bg) # fill the screen with the background color
115
-
116
- self.render_agents(screen, state) # render the agents
117
- self.render_action(screen, self.action_seq[idx // 8])
118
- self.render_obstacles(screen) # render the obstacles
119
-
120
- # bullets
121
- if idx < len(self.bullet_seq) * 8:
122
- bullets = self.bullet_seq[idx // 8]
123
- self.render_bullets(screen, bullets, idx % 8)
124
-
125
- # rotate the screen and append to frames
126
- frames.append(pygame.surfarray.pixels3d(screen).swapaxes(0, 1))
127
-
128
- # save the images
129
- clip = ImageSequenceClip(frames, fps=48)
130
- clip.write_videofile(save_fname, fps=48)
131
- # clip.write_gif(save_fname.replace(".mp4", ".gif"), fps=24)
132
- pygame.quit()
133
-
134
- return clip
135
-
136
-
137
- # functions
138
- # bullet functions
139
- def dist_fn(env, pos): # computing the distances between all ally and enemy agents
140
- delta = pos[None, :, :] - pos[:, None, :]
141
- dist = jnp.sqrt((delta**2).sum(axis=2))
142
- dist = dist[: env.num_allies, env.num_allies :]
143
- return {"ally": dist, "enemy": dist.T}
144
-
145
-
146
- def range_fn(env, dists, ranges): # computing what targets are in range
147
- ally_range = dists["ally"] < ranges[: env.num_allies][:, None]
148
- enemy_range = dists["enemy"] < ranges[env.num_allies :][:, None]
149
- return {"ally": ally_range, "enemy": enemy_range}
150
-
151
-
152
- def target_fn(acts, in_range, team): # computing the one hot valid targets
153
- t_acts = jnp.stack([v for k, v in acts.items() if k.startswith(team)]).T
154
- t_targets = jnp.where(t_acts > 4, -1, t_acts - 5) # first 5 are move actions
155
- t_attacks = jnp.eye(in_range[team].shape[1] + 1)[t_targets][:, :-1]
156
- return t_attacks * in_range[team] # one hot valid targets
157
-
158
-
159
- def attack_fn(env, state_seq): # one hot attack list
160
- attacks = []
161
- for _, state, acts in state_seq:
162
- dists = dist_fn(env, state.unit_positions)
163
- ranges = env.unit_type_attack_ranges[state.unit_types]
164
- in_range = range_fn(env, dists, ranges)
165
- target = partial(target_fn, acts, in_range)
166
- attack = {"ally": target("ally"), "enemy": target("enemy")}
167
- attacks.append(attack)
168
- return attacks
169
-
170
-
171
- def bullet_fn(env, states):
172
- bullet_seq = []
173
- attack_seq = attack_fn(env, states)
174
-
175
- def aux_fn(team):
176
- bullets = jnp.stack(jnp.where(one_hot[team] == 1)).T
177
- # bullets = bullets.at[:, 2 if team == "ally" else 1].add(env.num_allies)
178
- return bullets
179
-
180
- state_zip = zip(states[:-1], states[1:])
181
- for i, ((_, state, _), (_, n_state, _)) in enumerate(state_zip):
182
- one_hot = attack_seq[i]
183
- ally_bullets, enemy_bullets = aux_fn("ally"), aux_fn("enemy")
184
-
185
- ally_bullets_source = state.unit_positions[ally_bullets[:, 0]]
186
- enemy_bullets_target = n_state.unit_positions[enemy_bullets[:, 1]]
187
-
188
- enemy_bullets_source = state.unit_positions[
189
- enemy_bullets[:, 0] + env.num_allies
190
- ]
191
- ally_bullets_target = n_state.unit_positions[
192
- ally_bullets[:, 1] + env.num_allies
193
- ]
194
-
195
- ally_bullets = jnp.stack((ally_bullets_source, ally_bullets_target), axis=1)
196
- enemy_bullets = jnp.stack((enemy_bullets_source, enemy_bullets_target), axis=1)
197
-
198
- bullet_seq.append((ally_bullets, enemy_bullets))
199
- return bullet_seq
200
-
201
-
202
- # test the visualizer
203
- if __name__ == "__main__":
204
- from parabellum import Parabellum, Scenario
205
- from jax import random, numpy as jnp
206
-
207
- s = Scenario(jnp.array([[16, 0]]),
208
- jnp.array([[0, 32]]) * 8,
209
- jnp.zeros((19,), dtype=jnp.uint8),
210
- 9,
211
- 10)
212
- env = Parabellum(map_width=32, map_height=32, walls_cause_death=False, scenario=s)
213
- rng, key = random.split(random.PRNGKey(0))
214
- obs, state = env.reset(key)
215
- state_seq = []
216
- for step in range(50):
217
- rng, key = random.split(rng)
218
- key_act = random.split(key, len(env.agents))
219
- actions = {
220
- agent: jnp.array(1)
221
- for i, agent in enumerate(env.agents)
222
- }
223
- state_seq.append((key, state, actions))
224
- rng, key_step = random.split(rng)
225
- obs, state, reward, done, infos = env.step(key_step, state, actions)
226
-
227
- vis = Visualizer(env, state_seq)
228
- vis.animate()
229
-
230
-
parabellum/vis.py DELETED
@@ -1,230 +0,0 @@
1
- """Visualizer for the Parabellum environment"""
2
-
3
- from tqdm import tqdm
4
- import jax.numpy as jnp
5
- import jax
6
- from jax import vmap
7
- from functools import partial
8
- import darkdetect
9
- import pygame
10
- from moviepy.editor import ImageSequenceClip
11
- from typing import Optional
12
- from jaxmarl.environments.multi_agent_env import MultiAgentEnv
13
- from jaxmarl.viz.visualizer import SMAXVisualizer
14
-
15
- # default dict
16
- from collections import defaultdict
17
-
18
-
19
- # constants
20
- action_to_symbol = {0: "↑", 1: "→", 2: "↓", 3: "←", 4: "Ø"}
21
-
22
-
23
- class Visualizer(SMAXVisualizer):
24
- def __init__(self, env: MultiAgentEnv, state_seq, reward_seq=None):
25
- super().__init__(env, state_seq, reward_seq)
26
- # remove fig and ax from super
27
- self.fig, self.ax = None, None
28
- self.bg = (0, 0, 0) if darkdetect.isDark() else (255, 255, 255)
29
- self.fg = (235, 235, 235) if darkdetect.isDark() else (20, 20, 20)
30
- self.s = 1000
31
- self.scale = self.s / self.env.map_width
32
- self.action_seq = [action for _, _, action in state_seq] # bcs SMAX bug
33
- self.bullet_seq = bullet_fn(self.env, self.state_seq)
34
-
35
- def render_agents(self, screen, state):
36
- time_tuple = zip(
37
- state.unit_positions,
38
- state.unit_teams,
39
- state.unit_types,
40
- state.unit_health,
41
- )
42
- for idx, (pos, team, kind, hp) in enumerate(time_tuple):
43
- face_col = self.fg if int(team.item()) == 0 else self.bg
44
- pos = tuple((pos * self.scale).tolist())
45
-
46
- # draw the agent
47
- if hp > 0:
48
- hp_frac = hp / self.env.unit_type_health[kind]
49
- unit_size = self.env.unit_type_radiuses[kind]
50
- radius = jnp.ceil((unit_size * self.scale * hp_frac)).astype(int) + 1
51
- pygame.draw.circle(screen, face_col, pos, radius)
52
- pygame.draw.circle(screen, self.fg, pos, radius, 1)
53
-
54
- # draw the sight range
55
- # sight_range = self.env.unit_type_sight_ranges[kind] * self.scale
56
- # pygame.draw.circle(screen, self.fg, pos, sight_range.astype(int), 2)
57
-
58
- # draw attack range
59
- # attack_range = self.env.unit_type_attack_ranges[kind] * self.scale
60
- # pygame.draw.circle(screen, self.fg, pos, attack_range.astype(int), 2)
61
- # work out which agents are being shot
62
-
63
- def render_action(self, screen, action):
64
- def coord_fn(idx, n, team):
65
- return (
66
- self.s / 20 if team == 0 else self.s - self.s / 20,
67
- # vertically centered so that n / 2 is above and below the center
68
- self.s / 2 - (n / 2) * self.s / 20 + idx * self.s / 20,
69
- )
70
-
71
- for idx in range(self.env.num_allies):
72
- symb = action_to_symbol.get(action[f"ally_{idx}"].astype(int).item(), "Ø")
73
- font = pygame.font.SysFont("Fira Code", jnp.sqrt(self.s).astype(int).item())
74
- text = font.render(symb, True, self.fg)
75
- coord = coord_fn(idx, self.env.num_allies, 0)
76
- screen.blit(text, coord)
77
-
78
- for idx in range(self.env.num_enemies):
79
- symb = action_to_symbol.get(action[f"enemy_{idx}"].astype(int).item(), "Ø")
80
- font = pygame.font.SysFont("Fira Code", jnp.sqrt(self.s).astype(int).item())
81
- text = font.render(symb, True, self.fg)
82
- coord = coord_fn(idx, self.env.num_enemies, 1)
83
- screen.blit(text, coord)
84
-
85
- def render_obstacles(self, screen):
86
- for c, d in zip(self.env.obstacle_coords, self.env.obstacle_deltas):
87
- d = tuple(((c + d) * self.scale).tolist())
88
- c = tuple((c * self.scale).tolist())
89
- pygame.draw.line(screen, self.fg, c, d, 5)
90
-
91
- def render_bullets(self, screen, bullets, jdx):
92
- jdx += 1
93
- ally_bullets, enemy_bullets = bullets
94
- for source, target in ally_bullets:
95
- position = source + (target - source) * jdx / 8
96
- position *= self.scale
97
- pygame.draw.circle(screen, self.fg, tuple(position.tolist()), 3)
98
- for source, target in enemy_bullets:
99
- position = source + (target - source) * jdx / 8
100
- position *= self.scale
101
- pygame.draw.circle(screen, self.fg, tuple(position.tolist()), 3)
102
-
103
- def animate(self, save_fname: str = "parabellum.mp4"):
104
- if not self.have_expanded:
105
- self.expand_state_seq()
106
- frames = [] # frames for the video
107
- pygame.init() # initialize pygame
108
- for idx, (_, state, _) in tqdm(
109
- enumerate(self.state_seq), total=len(self.state_seq)
110
- ):
111
- screen = pygame.Surface(
112
- (self.s, self.s), pygame.HWSURFACE | pygame.DOUBLEBUF
113
- )
114
- screen.fill(self.bg) # fill the screen with the background color
115
-
116
- self.render_agents(screen, state) # render the agents
117
- self.render_action(screen, self.action_seq[idx // 8])
118
- self.render_obstacles(screen) # render the obstacles
119
-
120
- # bullets
121
- if idx < len(self.bullet_seq) * 8:
122
- bullets = self.bullet_seq[idx // 8]
123
- self.render_bullets(screen, bullets, idx % 8)
124
-
125
- # rotate the screen and append to frames
126
- frames.append(pygame.surfarray.pixels3d(screen).swapaxes(0, 1))
127
-
128
- # save the images
129
- clip = ImageSequenceClip(frames, fps=48)
130
- clip.write_videofile(save_fname, fps=48)
131
- # clip.write_gif(save_fname.replace(".mp4", ".gif"), fps=24)
132
- pygame.quit()
133
-
134
- return clip
135
-
136
-
137
- # functions
138
- # bullet functions
139
- def dist_fn(env, pos): # computing the distances between all ally and enemy agents
140
- delta = pos[None, :, :] - pos[:, None, :]
141
- dist = jnp.sqrt((delta**2).sum(axis=2))
142
- dist = dist[: env.num_allies, env.num_allies :]
143
- return {"ally": dist, "enemy": dist.T}
144
-
145
-
146
- def range_fn(env, dists, ranges): # computing what targets are in range
147
- ally_range = dists["ally"] < ranges[: env.num_allies][:, None]
148
- enemy_range = dists["enemy"] < ranges[env.num_allies :][:, None]
149
- return {"ally": ally_range, "enemy": enemy_range}
150
-
151
-
152
- def target_fn(acts, in_range, team): # computing the one hot valid targets
153
- t_acts = jnp.stack([v for k, v in acts.items() if k.startswith(team)]).T
154
- t_targets = jnp.where(t_acts > 4, -1, t_acts - 5) # first 5 are move actions
155
- t_attacks = jnp.eye(in_range[team].shape[1] + 1)[t_targets][:, :-1]
156
- return t_attacks * in_range[team] # one hot valid targets
157
-
158
-
159
- def attack_fn(env, state_seq): # one hot attack list
160
- attacks = []
161
- for _, state, acts in state_seq:
162
- dists = dist_fn(env, state.unit_positions)
163
- ranges = env.unit_type_attack_ranges[state.unit_types]
164
- in_range = range_fn(env, dists, ranges)
165
- target = partial(target_fn, acts, in_range)
166
- attack = {"ally": target("ally"), "enemy": target("enemy")}
167
- attacks.append(attack)
168
- return attacks
169
-
170
-
171
- def bullet_fn(env, states):
172
- bullet_seq = []
173
- attack_seq = attack_fn(env, states)
174
-
175
- def aux_fn(team):
176
- bullets = jnp.stack(jnp.where(one_hot[team] == 1)).T
177
- # bullets = bullets.at[:, 2 if team == "ally" else 1].add(env.num_allies)
178
- return bullets
179
-
180
- state_zip = zip(states[:-1], states[1:])
181
- for i, ((_, state, _), (_, n_state, _)) in enumerate(state_zip):
182
- one_hot = attack_seq[i]
183
- ally_bullets, enemy_bullets = aux_fn("ally"), aux_fn("enemy")
184
-
185
- ally_bullets_source = state.unit_positions[ally_bullets[:, 0]]
186
- enemy_bullets_target = n_state.unit_positions[enemy_bullets[:, 1]]
187
-
188
- enemy_bullets_source = state.unit_positions[
189
- enemy_bullets[:, 0] + env.num_allies
190
- ]
191
- ally_bullets_target = n_state.unit_positions[
192
- ally_bullets[:, 1] + env.num_allies
193
- ]
194
-
195
- ally_bullets = jnp.stack((ally_bullets_source, ally_bullets_target), axis=1)
196
- enemy_bullets = jnp.stack((enemy_bullets_source, enemy_bullets_target), axis=1)
197
-
198
- bullet_seq.append((ally_bullets, enemy_bullets))
199
- return bullet_seq
200
-
201
-
202
- # test the visualizer
203
- if __name__ == "__main__":
204
- from parabellum import Parabellum, Scenario
205
- from jax import random, numpy as jnp
206
-
207
- s = Scenario(jnp.array([[16, 0]]),
208
- jnp.array([[0, 32]]) * 8,
209
- jnp.zeros((19,), dtype=jnp.uint8),
210
- 9,
211
- 10)
212
- env = Parabellum(map_width=32, map_height=32, walls_cause_death=False, scenario=s)
213
- rng, key = random.split(random.PRNGKey(0))
214
- obs, state = env.reset(key)
215
- state_seq = []
216
- for step in range(50):
217
- rng, key = random.split(rng)
218
- key_act = random.split(key, len(env.agents))
219
- actions = {
220
- agent: jnp.array(1)
221
- for i, agent in enumerate(env.agents)
222
- }
223
- state_seq.append((key, state, actions))
224
- rng, key_step = random.split(rng)
225
- obs, state, reward, done, infos = env.step(key_step, state, actions)
226
-
227
- vis = Visualizer(env, state_seq)
228
- vis.animate()
229
-
230
-
@@ -1,55 +0,0 @@
1
- Metadata-Version: 2.1
2
- Name: parabellum
3
- Version: 0.0.0
4
- Summary: Parabellum environment for parallel warfare simulation
5
- Home-page: https://github.com/syrkis/parabellum
6
- License: MIT
7
- Keywords: warfare,simulation,parallel,environment
8
- Author: Noah Syrkis
9
- Author-email: desk@syrkis.com
10
- Requires-Python: >=3.11,<4.0
11
- Classifier: License :: OSI Approved :: MIT License
12
- Classifier: Programming Language :: Python :: 3
13
- Classifier: Programming Language :: Python :: 3.11
14
- Classifier: Programming Language :: Python :: 3.12
15
- Requires-Dist: darkdetect (>=0.8.0,<0.9.0)
16
- Requires-Dist: jaxmarl (==0.0.3)
17
- Requires-Dist: jupyterlab (>=4.2.2,<5.0.0)
18
- Requires-Dist: moviepy (>=1.0.3,<2.0.0)
19
- Requires-Dist: poetry (>=1.8.3,<2.0.0)
20
- Requires-Dist: pygame (>=2.5.2,<3.0.0)
21
- Project-URL: Repository, https://github.com/syrkis/parabellum
22
- Description-Content-Type: text/markdown
23
-
24
- # parabellum
25
-
26
- Parabellum is an ultra-scalable, high-performance warfare simulation engine.
27
- It is based on JaxMARL's SMAX environment, but has been heavily modified to
28
- support a wide range of new features and improvements.
29
-
30
- ## Installation
31
-
32
- Install through PyPI:
33
-
34
- ```bash
35
- pip install parabellum
36
- ```
37
-
38
- ## Usage
39
-
40
- ```python
41
- import parabellum as pb
42
- ```
43
-
44
- ## TODO
45
-
46
- - [ ] Parallel pygame vis
47
- - [ ] Color for health?
48
- - [ ] Add the ability to see ongoing game.
49
- - [ ] Bug test friendly fire.
50
- - [ ] Start sim from arbitrary state.
51
- - [ ] Save when the episode ends in some state/obs variable
52
- - [ ] Look for the source of the bug when using more Allies than Enemies
53
- - [ ] Y inversed axis for parabellum visualization
54
- - [ ] Units see through obstacles?
55
-
@@ -1,9 +0,0 @@
1
- parabellum/.ipynb_checkpoints/__init__-checkpoint.py,sha256=Yt1RkvkGIJdps0Axpz0ouu-Aaa07032kX04l1l7LXTw,118
2
- parabellum/.ipynb_checkpoints/env-checkpoint.py,sha256=Z0PD3MJb9Amxl84MMtghTCF92Gr4ln9qSyRx2DSY15Y,11589
3
- parabellum/.ipynb_checkpoints/vis-checkpoint.py,sha256=7zmFqU99gXSW6ueTeEp3CKMJ9XmrTgJkVEpktdLWd_4,8999
4
- parabellum/__init__.py,sha256=Yt1RkvkGIJdps0Axpz0ouu-Aaa07032kX04l1l7LXTw,118
5
- parabellum/env.py,sha256=Z0PD3MJb9Amxl84MMtghTCF92Gr4ln9qSyRx2DSY15Y,11589
6
- parabellum/vis.py,sha256=7zmFqU99gXSW6ueTeEp3CKMJ9XmrTgJkVEpktdLWd_4,8999
7
- parabellum-0.0.0.dist-info/METADATA,sha256=cV1VBjjoFLEUmDebGlRpARDdJULpNQ6JspYJ5dqU5ns,1588
8
- parabellum-0.0.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
9
- parabellum-0.0.0.dist-info/RECORD,,