parabellum 0.2.26__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- parabellum/__init__.py +9 -5
- parabellum/aid.py +22 -0
- parabellum/env.py +269 -118
- parabellum/geo.py +130 -0
- parabellum/pcg.py +53 -0
- parabellum/run.py +1 -1
- parabellum/terrain_db.py +117 -0
- parabellum/tps.py +17 -0
- parabellum/vis.py +48 -26
- {parabellum-0.2.26.dist-info → parabellum-0.3.1.dist-info}/METADATA +9 -6
- parabellum-0.3.1.dist-info/RECORD +13 -0
- parabellum/map.py +0 -100
- parabellum-0.2.26.dist-info/RECORD +0 -10
- {parabellum-0.2.26.dist-info → parabellum-0.3.1.dist-info}/WHEEL +0 -0
parabellum/__init__.py
CHANGED
@@ -1,18 +1,22 @@
|
|
1
|
-
from .env import Environment, Scenario,
|
1
|
+
from .env import Environment, Scenario, make_scenario, State
|
2
2
|
from .vis import Visualizer, Skin
|
3
|
-
from .map import terrain_fn
|
4
3
|
from .gun import bullet_fn
|
5
|
-
|
4
|
+
from . import vis
|
5
|
+
from . import terrain_db
|
6
|
+
from . import env
|
7
|
+
from . import tps
|
6
8
|
# from .run import run
|
7
9
|
|
8
10
|
__all__ = [
|
11
|
+
"env",
|
12
|
+
"terrain_db",
|
13
|
+
"vis",
|
14
|
+
"tps",
|
9
15
|
"Environment",
|
10
16
|
"Scenario",
|
11
|
-
"scenarios",
|
12
17
|
"make_scenario",
|
13
18
|
"State",
|
14
19
|
"Visualizer",
|
15
20
|
"Skin",
|
16
|
-
"terrain_fn",
|
17
21
|
"bullet_fn",
|
18
22
|
]
|
parabellum/aid.py
CHANGED
@@ -3,3 +3,25 @@
|
|
3
3
|
# by: Noah Syrkis
|
4
4
|
|
5
5
|
# imports
|
6
|
+
import os
|
7
|
+
from collections import namedtuple
|
8
|
+
from typing import Tuple
|
9
|
+
import cartopy.crs as ccrs
|
10
|
+
|
11
|
+
# types
|
12
|
+
# Bounding box with fields (north, south, east, west); to_mercator and
# to_platecarree below return the same field layout in the target projection.
BBox = namedtuple("BBox", ["north", "south", "east", "west"]) # type: ignore
|
13
|
+
|
14
|
+
|
15
|
+
# coordinate function
|
16
|
+
def to_mercator(bbox: BBox) -> BBox:
    """Convert a lon/lat (PlateCarree) bounding box to Mercator coordinates.

    The south-west and north-east corners are transformed independently and
    the result is returned as a new ``BBox``.
    """
    target = ccrs.Mercator()
    lower_left = target.transform_point(bbox.west, bbox.south, ccrs.PlateCarree())
    upper_right = target.transform_point(bbox.east, bbox.north, ccrs.PlateCarree())
    west, south = lower_left
    east, north = upper_right
    return BBox(north=north, south=south, east=east, west=west)
|
21
|
+
|
22
|
+
|
23
|
+
def to_platecarree(bbox: BBox) -> BBox:
    """Convert a Mercator bounding box back to lon/lat (PlateCarree).

    Inverse of ``to_mercator``: each corner is transformed separately.
    """
    target = ccrs.PlateCarree()
    lower_left = target.transform_point(bbox.west, bbox.south, ccrs.Mercator())
    upper_right = target.transform_point(bbox.east, bbox.north, ccrs.Mercator())
    west, south = lower_left
    east, north = upper_right
    return BBox(north=north, south=south, east=east, west=west)
|
parabellum/env.py
CHANGED
@@ -2,15 +2,16 @@
|
|
2
2
|
|
3
3
|
import jax.numpy as jnp
|
4
4
|
import jax
|
5
|
-
import
|
6
|
-
from jax import random, Array
|
7
|
-
from jax import jit
|
5
|
+
from jax import random, Array, vmap, jit
|
8
6
|
from flax.struct import dataclass
|
9
7
|
import chex
|
10
|
-
from jax import vmap
|
11
8
|
from jaxmarl.environments.smax.smax_env import SMAX
|
9
|
+
|
10
|
+
from math import ceil
|
11
|
+
|
12
12
|
from typing import Tuple, Dict, cast
|
13
13
|
from functools import partial
|
14
|
+
from parabellum import tps, geo, terrain_db
|
14
15
|
|
15
16
|
|
16
17
|
@dataclass
|
@@ -18,8 +19,8 @@ class Scenario:
|
|
18
19
|
"""Parabellum scenario"""
|
19
20
|
|
20
21
|
place: str
|
21
|
-
|
22
|
-
unit_starting_sectors: jnp.ndarray
|
22
|
+
terrain: tps.Terrain
|
23
|
+
unit_starting_sectors: jnp.ndarray # must be of size (num_units, 4) where sectors[i] = (x, y, width, height) of the ith unit's spawning sector (in % of the real map)
|
23
24
|
unit_types: chex.Array
|
24
25
|
num_allies: int
|
25
26
|
num_enemies: int
|
@@ -27,9 +28,11 @@ class Scenario:
|
|
27
28
|
smacv2_position_generation: bool = False
|
28
29
|
smacv2_unit_type_generation: bool = False
|
29
30
|
|
31
|
+
|
30
32
|
@dataclass
|
31
33
|
class State:
|
32
|
-
|
34
|
+
# terrain: Array
|
35
|
+
unit_positions: Array # fsfds
|
33
36
|
unit_alive: Array
|
34
37
|
unit_teams: Array
|
35
38
|
unit_health: Array
|
@@ -41,93 +44,156 @@ class State:
|
|
41
44
|
terminal: bool
|
42
45
|
|
43
46
|
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
47
|
+
def make_scenario(
    place,
    size,
    unit_starting_sectors,
    allies_type,
    n_allies,
    enemies_type,
    n_enemies,
):
    """Build a ``Scenario`` on a ``size`` x ``size`` map of ``place``.

    Terrain is taken from ``terrain_db.db`` when ``place`` is one of its keys
    and is otherwise generated with ``geo.geography_fn(place, size)``.

    Args:
        place: a ``terrain_db.db`` key or a geocodable place name.
        size: map side length in cells (also used as the geo buffer in metres).
        unit_starting_sectors: either a list of ``(unit_ids, sector)`` pairs,
            where ``sector = [x, y, width, height]`` in fractions of the map,
            or any non-list value, which is passed through unchanged. With the
            list form, units missing from every pair may spawn anywhere, and a
            unit listed in several pairs uses the last matching pair.
        allies_type: one unit-type id for every ally, or a sequence with one
            id per ally.
        n_allies: number of allied units.
        enemies_type: one unit-type id for every enemy, or a sequence with one
            id per enemy.
        n_enemies: number of enemy units.

    Raises:
        ValueError: if a per-unit type sequence does not match its team size.
    """
    if place in terrain_db.db:
        terrain = terrain_db.make_terrain(terrain_db.db[place], size)
    else:
        terrain = geo.geography_fn(place, size)

    if isinstance(unit_starting_sectors, list):
        # Sectors are fractions of the map, so the whole map is [0, 0, 1, 1]
        # (the same default scenario_fn uses). The former [0, 0, size, size]
        # only covered the map because out-of-range slices are clamped.
        default_sector = [0, 0, 1, 1]
        per_unit_sectors = []
        for unit in range(n_allies + n_enemies):
            chosen = None
            for unit_ids, sector in unit_starting_sectors:
                if unit in unit_ids:
                    chosen = sector  # later pairs override earlier ones
            per_unit_sectors.append(default_sector if chosen is None else chosen)
        unit_starting_sectors = per_unit_sectors

    if isinstance(allies_type, int):
        allies = [allies_type] * n_allies
    else:
        if len(allies_type) != n_allies:
            raise ValueError(
                f"allies_type has {len(allies_type)} entries, expected n_allies={n_allies}"
            )
        # list() so tuples / arrays concatenate instead of erroring or broadcasting
        allies = list(allies_type)

    if isinstance(enemies_type, int):
        enemies = [enemies_type] * n_enemies
    else:
        if len(enemies_type) != n_enemies:
            raise ValueError(
                f"enemies_type has {len(enemies_type)} entries, expected n_enemies={n_enemies}"
            )
        enemies = list(enemies_type)

    unit_types = jnp.array(allies + enemies, dtype=jnp.uint8)
    return Scenario(
        place=place,
        terrain=terrain,
        unit_starting_sectors=unit_starting_sectors,  # type: ignore
        unit_types=unit_types,
        num_allies=n_allies,
        num_enemies=n_enemies,
    )
|
83
97
|
|
84
98
|
|
85
|
-
def
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
#
|
92
|
-
|
99
|
+
def scenario_fn(place):
    """Simplified scenario builder: 10 allies vs 10 enemies on ``place``.

    Allies are unit type 0 and enemies unit type 1. Every unit may spawn
    anywhere on the map (sector ``[0, 0, 1, 1]``). Terrain comes from
    ``geo.geography_fn`` with its default buffer.
    """
    terrain = geo.geography_fn(place)
    num_allies = num_enemies = 10
    team_types = [0] * num_allies + [1] * num_enemies
    whole_map = [0, 0, 1, 1]
    return Scenario(
        place=place,
        terrain=terrain,
        unit_starting_sectors=jnp.array([whole_map] * (num_allies + num_enemies)),
        unit_types=jnp.array(team_types, dtype=jnp.uint8),
        num_allies=num_allies,
        num_enemies=num_enemies,
    )
|
93
115
|
|
94
116
|
|
95
|
-
def
|
96
|
-
"""
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
117
|
+
def spawn_fn(rng: jnp.ndarray, units_spawning_sectors):
    """Sample one spawn position per unit.

    Each entry of ``units_spawning_sectors`` is a pair of index arrays
    ``(xs, ys)`` listing the valid cells of one unit's sector (as built by
    ``sectors_fn``). A cell is drawn uniformly at random and offset by a
    uniform jitter in ``[0.25, 0.75)`` per axis so the unit lies inside it.

    Returns:
        float32 array of shape ``(num_units, 2)``.
    """
    key = rng
    positions = []
    for cells in units_spawning_sectors:
        key, pick_key, jitter_key = random.split(key, 3)
        jitter = 0.25 + random.uniform(jitter_key, (2,)) * 0.5
        xs, ys = cells[0], cells[1]
        pick = random.choice(pick_key, xs.shape[0])
        positions.append(jnp.array([xs[pick], ys[pick]]) + jitter)
    return jnp.array(positions, dtype=jnp.float32)
|
127
|
+
|
128
|
+
|
129
|
+
def sectors_fn(sectors: jnp.ndarray, invalid_spawn_areas: jnp.ndarray):
    """Turn per-unit spawn sectors into arrays of valid spawn cells.

    sectors must be of size (num_units, 4) where sectors[i] = (x, y, width,
    height) of the ith unit's spawning sector, given as fractions (0..1) of
    the map. ``invalid_spawn_areas`` is a 2-D raster whose non-zero cells may
    not be spawned on; ``x`` indexes its first axis and ``y`` its second.

    Returns:
        A list with one int array of shape ``(2, k)`` per unit holding the
        map indices of the ``k`` free cells inside that unit's sector.

    Raises:
        ValueError: if a sector contains no free cell.
    """
    width, height = invalid_spawn_areas.shape
    spawning_sectors = []
    for spec in sectors:
        x0 = jnp.array(spec[0] * width, dtype=jnp.int32)
        y0 = jnp.array(spec[1] * height, dtype=jnp.int32)
        window = invalid_spawn_areas[
            x0 : x0 + ceil(spec[2] * width),
            y0 : y0 + ceil(spec[3] * height),
        ]
        free = jnp.nonzero(window == 0)
        if free[0].shape[0] == 0:
            # Report the sector spec itself; the old code printed the
            # (possibly huge) boolean window because `sector` had been rebound.
            raise ValueError(f"Sector {spec} only contains invalid spawn areas.")
        # shift window-local indices back to map coordinates
        spawning_sectors.append(
            jnp.array(free) + jnp.array([x0, y0]).reshape((2, -1))
        )
    return spawning_sectors
|
103
154
|
|
104
155
|
|
105
156
|
class Environment(SMAX):
|
106
157
|
def __init__(self, scenario: Scenario, **kwargs):
    """Wrap JaxMARL's SMAX with Parabellum terrain and spawn sectors.

    The map size is read from ``scenario.terrain.building``. Keyword options
    consumed here (and not forwarded to SMAX):

    - ``unit_type_pushable``: per-unit-type flags, default ``[1, 1, 0, 0, 0, 1]``.
    - ``reset_when_done``: stored on the instance, default ``True``.

    ``unit_type_velocities`` overrides the default velocities but, like every
    other remaining keyword, is still forwarded to SMAX.
    """
    map_height, map_width = scenario.terrain.building.shape
    args = dict(scenario=scenario, map_height=map_height, map_width=map_width)
    self.unit_type_pushable = kwargs.pop(
        "unit_type_pushable", jnp.array([1, 1, 0, 0, 0, 1])
    )
    self.reset_when_done = kwargs.pop("reset_when_done", True)
    super(Environment, self).__init__(**args, walls_cause_death=False, **kwargs)

    self.terrain = scenario.terrain
    self.unit_starting_sectors = scenario.unit_starting_sectors
    self.scenario = scenario
    default_velocities = jnp.array([3.15, 2.25, 4.13, 3.15, 4.13, 3.15]) / 2.5
    self.unit_type_velocities = kwargs.get("unit_type_velocities", default_velocities)
    self.unit_type_attack_blasts = jnp.zeros((3,), dtype=jnp.float32)  # TODO: add
    self.max_steps = 200
    # Replace SMAX's pushing with an identity: calls leave the state unchanged.
    self._push_units_away = lambda state, firmness=1: state
    # Free spawn cells per unit; buildings and water are not spawnable.
    self.spawning_sectors = sectors_fn(
        self.unit_starting_sectors,
        scenario.terrain.building + scenario.terrain.water,
    )
    # Samples per line-of-sight ray: twice the largest sight range.
    self.resolution = (
        jnp.array(jnp.max(self.unit_type_sight_ranges), dtype=jnp.int32) * 2
    )
    self.t = jnp.tile(jnp.linspace(0, 1, self.resolution), (2, 1))
|
123
193
|
|
124
|
-
|
125
|
-
def reset(self, rng: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]:
|
194
|
+
def reset(self, rng: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]: # type: ignore
|
126
195
|
"""Environment-specific reset."""
|
127
|
-
|
128
|
-
team_0_start = spawn_fn(self.team0_sector, self.team0_sector_offset, self.num_allies, ally_key)
|
129
|
-
team_1_start = spawn_fn(self.team1_sector, self.team1_sector_offset, self.num_enemies, enemy_key)
|
130
|
-
unit_positions = jnp.concatenate([team_0_start, team_1_start])
|
196
|
+
unit_positions = spawn_fn(rng, self.spawning_sectors)
|
131
197
|
unit_teams = jnp.zeros((self.num_agents,))
|
132
198
|
unit_teams = unit_teams.at[self.num_allies :].set(1)
|
133
199
|
unit_weapon_cooldowns = jnp.zeros((self.num_agents,))
|
@@ -145,18 +211,71 @@ class Environment(SMAX):
|
|
145
211
|
time=0,
|
146
212
|
terminal=False,
|
147
213
|
unit_weapon_cooldowns=unit_weapon_cooldowns,
|
214
|
+
# terrain=self.terrain,
|
148
215
|
)
|
149
|
-
state = self._push_units_away(state) # type: ignore
|
216
|
+
state = self._push_units_away(state) # type: ignore could be slow
|
150
217
|
obs = self.get_obs(state)
|
218
|
+
# remove world_state from obs
|
151
219
|
world_state = self.get_world_state(state)
|
152
|
-
|
220
|
+
obs["world_state"] = jax.lax.stop_gradient(world_state)
|
153
221
|
return obs, state
|
154
222
|
|
155
|
-
|
156
|
-
obs, state, rewards, dones, infos = super().step_env(rng, state, action)
|
223
|
+
# def step_env(self, rng, state: State, action: Array): # type: ignore
|
224
|
+
# obs, state, rewards, dones, infos = super().step_env(rng, state, action)
|
157
225
|
# delete world_state from obs
|
158
|
-
obs.pop("world_state")
|
159
|
-
|
226
|
+
# obs.pop("world_state")
|
227
|
+
# if not self.reset_when_done:
|
228
|
+
# for key in dones.keys():
|
229
|
+
# dones[key] = False
|
230
|
+
# return obs, state, rewards, dones, infos
|
231
|
+
|
232
|
+
def get_obs_unit_list(self, state: State) -> Dict[str, chex.Array]:  # type: ignore
    """Applies observation function to state.

    For each unit ``i`` the row is the features of the ``num_agents - 1``
    other units, followed by ``i``'s own features. An other unit's features
    are zeroed unless it is within ``i``'s sight range, both units are alive,
    and ``has_line_of_sight`` finds no building/forest cell between them.
    Returns a dict from agent name to its observation row.
    """

    def get_features(i, j):
        """Features of the j-th *other* unit as seen from unit i.

        Allies enumerate the other units in ascending index order and enemies
        in descending order, so each team sees its own side first.
        """
        i_is_ally = i < self.num_allies
        # translate "jth other unit" into an absolute index that skips i
        j = jax.lax.cond(i_is_ally, lambda: j, lambda: self.num_agents - j - 1)
        offset = jax.lax.cond(i_is_ally, lambda: 1, lambda: -1)
        before_self = ((j < i) & i_is_ally) | ((j > i) & ~i_is_ally)
        j_idx = jax.lax.cond(before_self, lambda: j, lambda: j + offset)

        empty_features = jnp.zeros(shape=(len(self.unit_features),))
        features = self._observe_features(state, i, j_idx)
        observer = state.unit_positions[i]
        observed = state.unit_positions[j_idx]
        in_range = (
            jnp.linalg.norm(observed - observer)
            < self.unit_type_sight_ranges[state.unit_types[i]]
        )
        unobstructed = self.has_line_of_sight(
            observed, observer, self.terrain.building + self.terrain.forest
        )
        observable = (
            in_range & state.unit_alive[i] & state.unit_alive[j_idx] & unobstructed
        )
        return jax.lax.cond(observable, lambda: features, lambda: empty_features)

    over_others = jax.vmap(get_features, in_axes=(None, 0))
    over_units = jax.vmap(over_others, in_axes=(0, None))
    other_unit_obs = over_units(
        jnp.arange(self.num_agents), jnp.arange(self.num_agents - 1)
    ).reshape((self.num_agents, -1))
    own_unit_obs = jax.vmap(self._get_own_features, in_axes=(None, 0))(
        state, jnp.arange(self.num_agents)
    )
    obs = jnp.concatenate([other_unit_obs, own_unit_obs], axis=-1)
    return {agent: obs[self.agent_ids[agent]] for agent in self.agents}
|
160
279
|
|
161
280
|
def _our_push_units_away(
|
162
281
|
self, pos, unit_types, firmness: float = 1.0
|
@@ -176,7 +295,19 @@ class Environment(SMAX):
|
|
176
295
|
pos
|
177
296
|
+ firmness * jnp.sum(delta_matrix * overlap_term[:, :, None], axis=1) / 2
|
178
297
|
)
|
179
|
-
return
|
298
|
+
return jnp.where(
|
299
|
+
self.unit_type_pushable[unit_types][:, None], unit_positions, pos
|
300
|
+
)
|
301
|
+
|
302
|
+
def has_line_of_sight(self, source, target, raster_input):
    """Return True when no obstacle cell lies on the segment target→source.

    The segment is sampled at the ``self.resolution`` points in ``self.t``.
    Each sample is truncated to an integer cell, and the result is False if
    any sampled cell is non-zero in ``raster_input``. The source and target
    cells themselves are included.

    Assumes ``target`` is within ``source``'s sight range: ``self.resolution``
    is derived from the largest sight range, so longer segments may skip cells.
    """
    cells = jnp.array(
        source[:, jnp.newaxis] * self.t + (1 - self.t) * target[:, jnp.newaxis],
        dtype=jnp.int32,
    )
    # Gather the raster values along the path instead of scattering into a
    # freshly allocated map-sized mask: O(resolution) work per call instead of
    # O(H * W), and this runs for every (observer, observed) pair.
    # mode="fill" reproduces the old scatter, which dropped out-of-bounds cells.
    on_path = raster_input.at[cells[0, :], cells[1, :]].get(
        mode="fill", fill_value=0
    )
    return ~jnp.any(on_path != 0)
|
180
311
|
|
181
312
|
@partial(jax.jit, static_argnums=(0,)) # replace the _world_step method
|
182
313
|
def _world_step( # modified version of JaxMARL's SMAX _world_step
|
@@ -185,24 +316,15 @@ class Environment(SMAX):
|
|
185
316
|
state: State,
|
186
317
|
actions: Tuple[chex.Array, chex.Array],
|
187
318
|
) -> State:
|
188
|
-
|
189
|
-
def intersect_fn(pos, new_pos, obs, obs_end):
|
190
|
-
d1 = jnp.cross(obs - pos, new_pos - pos)
|
191
|
-
d2 = jnp.cross(obs_end - pos, new_pos - pos)
|
192
|
-
d3 = jnp.cross(pos - obs, obs_end - obs)
|
193
|
-
d4 = jnp.cross(new_pos - obs, obs_end - obs)
|
194
|
-
return (d1 * d2 <= 0) & (d3 * d4 <= 0)
|
195
|
-
|
196
|
-
def raster_crossing(pos, new_pos):
|
319
|
+
def raster_crossing(pos, new_pos, mask: jnp.ndarray):
    """Return True if any non-zero cell of ``mask`` lies in the box spanned
    by ``pos`` and ``new_pos``.

    Both points are truncated to int32 and the box bounds are inclusive on
    both axes, which makes this a conservative obstacle test for a move.
    """
    start, end = pos.astype(jnp.int32), new_pos.astype(jnp.int32)
    lo, hi = jnp.minimum(start, end), jnp.maximum(start, end)
    rows = jnp.arange(mask.shape[0])[:, None]
    cols = jnp.arange(mask.shape[1])[None, :]
    in_box = (rows >= lo[0]) & (rows <= hi[0]) & (cols >= lo[1]) & (cols <= hi[1])
    return jnp.any(jnp.where(in_box, mask, 0))
|
206
328
|
|
207
329
|
def update_position(idx, vec):
|
208
330
|
# Compute the movements slightly strangely.
|
@@ -219,13 +341,17 @@ class Environment(SMAX):
|
|
219
341
|
)
|
220
342
|
# avoid going out of bounds
|
221
343
|
new_pos = jnp.maximum(
|
222
|
-
jnp.minimum(
|
344
|
+
jnp.minimum(
|
345
|
+
new_pos, jnp.array([self.map_width - 1, self.map_height - 1])
|
346
|
+
),
|
223
347
|
jnp.zeros((2,)),
|
224
348
|
)
|
225
349
|
|
226
350
|
#######################################################################
|
227
351
|
############################################ avoid going into obstacles
|
228
|
-
clash = raster_crossing(
|
352
|
+
clash = raster_crossing(
|
353
|
+
pos, new_pos, self.terrain.building + self.terrain.water
|
354
|
+
)
|
229
355
|
new_pos = jnp.where(clash, pos, new_pos)
|
230
356
|
|
231
357
|
#######################################################################
|
@@ -263,14 +389,11 @@ class Environment(SMAX):
|
|
263
389
|
attacked_idx = jax.lax.select(
|
264
390
|
action < self.num_movement_actions, idx, attacked_idx
|
265
391
|
)
|
266
|
-
|
392
|
+
distance = jnp.linalg.norm(
|
393
|
+
state.unit_positions[idx] - state.unit_positions[attacked_idx]
|
394
|
+
)
|
267
395
|
attack_valid = (
|
268
|
-
(
|
269
|
-
jnp.linalg.norm(
|
270
|
-
state.unit_positions[idx] - state.unit_positions[attacked_idx]
|
271
|
-
)
|
272
|
-
< self.unit_type_attack_ranges[state.unit_types[idx]]
|
273
|
-
)
|
396
|
+
(distance <= self.unit_type_attack_ranges[state.unit_types[idx]])
|
274
397
|
& state.unit_alive[idx]
|
275
398
|
& state.unit_alive[attacked_idx]
|
276
399
|
)
|
@@ -281,21 +404,28 @@ class Environment(SMAX):
|
|
281
404
|
-self.unit_type_attacks[state.unit_types[idx]],
|
282
405
|
0.0,
|
283
406
|
)
|
407
|
+
health_diff = jnp.where(
|
408
|
+
state.unit_types[idx] == 1,
|
409
|
+
health_diff
|
410
|
+
* distance
|
411
|
+
/ self.unit_type_attack_ranges[state.unit_types[idx]],
|
412
|
+
health_diff,
|
413
|
+
)
|
284
414
|
# design choice based on the pysc2 randomness details.
|
285
415
|
# See https://github.com/deepmind/pysc2/blob/master/docs/environment.md#determinism-and-randomness
|
286
416
|
|
287
417
|
#########################################################
|
288
418
|
############################### Add bystander health diff
|
289
419
|
|
290
|
-
bystander_idxs = bystander_fn(attacked_idx) # TODO: use
|
291
|
-
bystander_valid = (
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
)
|
296
|
-
bystander_health_diff = (
|
297
|
-
|
298
|
-
)
|
420
|
+
# bystander_idxs = bystander_fn(attacked_idx) # TODO: use
|
421
|
+
# bystander_valid = (
|
422
|
+
# jnp.where(attack_valid, bystander_idxs, jnp.zeros((self.num_agents,)))
|
423
|
+
# .astype(jnp.bool_) # type: ignore
|
424
|
+
# .astype(jnp.float32)
|
425
|
+
# )
|
426
|
+
# bystander_health_diff = (
|
427
|
+
# bystander_valid * -self.unit_type_attacks[state.unit_types[idx]]
|
428
|
+
# )
|
299
429
|
|
300
430
|
#########################################################
|
301
431
|
#########################################################
|
@@ -319,29 +449,34 @@ class Environment(SMAX):
|
|
319
449
|
health_diff,
|
320
450
|
attacked_idx,
|
321
451
|
cooldown_diff,
|
322
|
-
(bystander_health_diff, bystander_idxs),
|
452
|
+
# (bystander_health_diff, bystander_idxs),
|
323
453
|
)
|
324
454
|
|
325
455
|
def perform_agent_action(idx, action, key):
    """Apply one agent's ``(movement, attack)`` action.

    Returns the agent's new position, its ``(health_diff, attacked_idx)``
    pair, and its cooldown change, as computed by ``update_position`` and
    ``update_agent_health``.
    """
    movement, attack = action
    moved_to = update_position(idx, movement)
    health_diff, attacked_idxes, cooldown_diff = update_agent_health(
        idx, attack, key
    )
    return moved_to, (health_diff, attacked_idxes), cooldown_diff
|
333
463
|
|
334
464
|
keys = jax.random.split(key, num=self.num_agents)
|
335
|
-
pos, (health_diff, attacked_idxes), cooldown_diff
|
465
|
+
pos, (health_diff, attacked_idxes), cooldown_diff = jax.vmap(
|
336
466
|
perform_agent_action
|
337
467
|
)(jnp.arange(self.num_agents), actions, keys)
|
338
468
|
|
339
469
|
# units push each other
|
340
470
|
new_pos = self._our_push_units_away(pos, state.unit_types)
|
341
|
-
clash = jax.vmap(raster_crossing
|
471
|
+
clash = jax.vmap(raster_crossing, in_axes=(0, 0, None))(
|
472
|
+
pos, new_pos, self.terrain.building + self.terrain.water
|
473
|
+
)
|
342
474
|
pos = jax.vmap(jnp.where)(clash, pos, new_pos)
|
343
475
|
# avoid going out of bounds
|
344
|
-
pos = jnp.maximum(
|
476
|
+
pos = jnp.maximum(
|
477
|
+
jnp.minimum(pos, jnp.array([self.map_width - 1, self.map_height - 1])), # type: ignore
|
478
|
+
jnp.zeros((2,)),
|
479
|
+
)
|
345
480
|
|
346
481
|
# Multiple enemies can attack the same unit.
|
347
482
|
# We have `(health_diff, attacked_idx)` pairs.
|
@@ -370,8 +505,8 @@ class Environment(SMAX):
|
|
370
505
|
#########################################################
|
371
506
|
############################ subtracting bystander health
|
372
507
|
|
373
|
-
_, bystander_health_diff = bystander
|
374
|
-
unit_health -= bystander_health_diff.sum(axis=0) # might be axis=1
|
508
|
+
# _, bystander_health_diff = bystander
|
509
|
+
# unit_health -= bystander_health_diff.sum(axis=0) # might be axis=1
|
375
510
|
|
376
511
|
#########################################################
|
377
512
|
#########################################################
|
@@ -389,15 +524,30 @@ class Environment(SMAX):
|
|
389
524
|
if __name__ == "__main__":
|
390
525
|
n_envs = 4
|
391
526
|
|
392
|
-
|
527
|
+
n_allies = 10
|
528
|
+
scenario_kwargs = {
|
529
|
+
"allies_type": 0,
|
530
|
+
"n_allies": n_allies,
|
531
|
+
"enemies_type": 0,
|
532
|
+
"n_enemies": n_allies,
|
533
|
+
"place": "Vesterbro, Copenhagen, Denmark",
|
534
|
+
"size": 100,
|
535
|
+
"unit_starting_sectors": [
|
536
|
+
([i for i in range(n_allies)], [0.0, 0.45, 0.1, 0.1]),
|
537
|
+
([n_allies + i for i in range(n_allies)], [0.8, 0.5, 0.1, 0.1]),
|
538
|
+
],
|
539
|
+
}
|
540
|
+
scenario = make_scenario(**scenario_kwargs)
|
541
|
+
env = Environment(scenario)
|
393
542
|
rng, reset_rng = random.split(random.PRNGKey(0))
|
394
543
|
reset_key = random.split(reset_rng, n_envs)
|
395
544
|
obs, state = vmap(env.reset)(reset_key)
|
396
545
|
state_seq = []
|
397
546
|
|
398
|
-
|
399
|
-
exit()
|
547
|
+
import time
|
400
548
|
|
549
|
+
step = vmap(jit(env.step))
|
550
|
+
tic = time.time()
|
401
551
|
for i in range(10):
|
402
552
|
rng, act_rng, step_rng = random.split(rng, 3)
|
403
553
|
act_key = random.split(act_rng, (len(env.agents), n_envs))
|
@@ -407,4 +557,5 @@ if __name__ == "__main__":
|
|
407
557
|
}
|
408
558
|
step_key = random.split(step_rng, n_envs)
|
409
559
|
state_seq.append((step_key, state, act))
|
410
|
-
obs, state, reward, done, infos =
|
560
|
+
obs, state, reward, done, infos = step(step_key, state, act)
|
561
|
+
tic = time.time()
|
parabellum/geo.py
ADDED
@@ -0,0 +1,130 @@
|
|
1
|
+
# %% geo.py
|
2
|
+
# script for geospatial level generation
|
3
|
+
# by: Noah Syrkis
|
4
|
+
|
5
|
+
# %% Imports
|
6
|
+
from parabellum import tps
|
7
|
+
import rasterio
|
8
|
+
from rasterio import features, transform
|
9
|
+
from geopy.geocoders import Nominatim
|
10
|
+
from geopy.distance import distance
|
11
|
+
import contextily as cx
|
12
|
+
import jax.numpy as jnp
|
13
|
+
import cartopy.crs as ccrs
|
14
|
+
from jaxtyping import Array
|
15
|
+
import numpy as np
|
16
|
+
from shapely import box
|
17
|
+
import osmnx as ox
|
18
|
+
import geopandas as gpd
|
19
|
+
from collections import namedtuple
|
20
|
+
from typing import Tuple
|
21
|
+
import matplotlib.pyplot as plt
|
22
|
+
import seaborn as sns
|
23
|
+
import os
|
24
|
+
from jax.scipy.signal import convolve
|
25
|
+
|
26
|
+
# %% Types
|
27
|
+
Coords = Tuple[float, float]
BBox = namedtuple("BBox", ["north", "south", "east", "west"])  # type: ignore

# %% Constants
# Stadia Maps (Stamen terrain) tile provider used by basemap_fn.
# NOTE(security): the fallback key below ships inside a public package and
# should be treated as leaked/rotated; set STADIA_API_KEY to use your own.
provider = cx.providers.Stadia.StamenTerrain(  # type: ignore
    api_key=os.environ.get("STADIA_API_KEY", "86d0d32b-d2fe-49af-8db8-f7751f58e83f")
)
provider["url"] = provider["url"] + "?api_key={api_key}"
# OpenStreetMap tags fetched per map. Key order matters: raster_fn emits one
# layer per key in this order and geography_fn reads the layers by index
# (0 building, 1 water, 2 highway, 3 landuse, 4 leisure).
tags = {
    "building": True,
    "water": True,
    "highway": True,
    "landuse": [
        "grass",
        "forest",
        "flowerbed",
        "greenfield",
        "village_green",
        "recreation_ground",
    ],
    "leisure": "garden",
}  # "road": True}
|
49
|
+
|
50
|
+
|
51
|
+
# %% Coordinate function
|
52
|
+
def get_coordinates(place: str) -> Coords:
    """Geocode ``place`` with Nominatim and return ``(latitude, longitude)``.

    Raises:
        ValueError: if Nominatim finds no match for ``place``.
    """
    geolocator = Nominatim(user_agent="parabellum")
    point = geolocator.geocode(place)
    if point is None:
        # geocode returns None when nothing matches; fail clearly instead of
        # with an AttributeError on ``None.latitude``.
        raise ValueError(f"Could not geocode place: {place!r}")
    return point.latitude, point.longitude  # type: ignore
|
56
|
+
|
57
|
+
|
58
|
+
def get_bbox(place: str, buffer) -> BBox:
    """Get bounding box from place name in crs 4326.

    The box reaches ``buffer`` metres from the geocoded centre of ``place``
    towards each cardinal direction.
    """
    center = get_coordinates(place)
    reach = distance(meters=buffer)
    north = reach.destination(center, bearing=0).latitude
    south = reach.destination(center, bearing=180).latitude
    east = reach.destination(center, bearing=90).longitude
    west = reach.destination(center, bearing=270).longitude
    return BBox(north=north, south=south, east=east, west=west)
|
66
|
+
|
67
|
+
|
68
|
+
def basemap_fn(bbox: BBox, gdf) -> Array:
    """Render a Stadia/Stamen terrain basemap covering ``gdf`` as an RGBA image.

    Args:
        bbox: currently unused; the extent comes from ``gdf.total_bounds``.
            Kept for backward compatibility with existing callers.
        gdf: projected features (geography_fn passes EPSG:3857). They are drawn
            fully transparent, only to give contextily the map extent.

    Returns:
        The rendered RGBA canvas pixels.
    """
    fig, ax = plt.subplots(figsize=(20, 20), subplot_kw={"projection": ccrs.Mercator()})
    try:
        gdf.plot(ax=ax, color="black", alpha=0, edgecolor="black")  # type: ignore
        cx.add_basemap(ax, crs=gdf.crs, source=provider, zoom="auto")  # type: ignore
        # don't shadow the `bbox` parameter with the frame's bounds
        west, south, east, north = gdf.total_bounds
        ax.set_extent([west, east, south, north], crs=ccrs.Mercator())  # type: ignore
        ax.set_axis_off()
        fig.tight_layout(pad=0)
        fig.canvas.draw()
        # public Agg API instead of the private `renderer._renderer`
        image = jnp.array(np.asarray(fig.canvas.buffer_rgba()))  # type: ignore
    finally:
        # always release the figure, even if fetching tiles fails
        plt.close(fig)
    return image
|
80
|
+
|
81
|
+
|
82
|
+
def geography_fn(place, buffer=400):
    """Build a ``tps.Terrain`` for the area around ``place``.

    OpenStreetMap features within ``buffer`` metres of the geocoded ``place``
    are fetched, clipped to the bounding box, projected to EPSG:3857, and
    rasterized onto a ``buffer`` x ``buffer`` grid. A basemap image is
    rendered alongside, and every layer is rotated with ``jnp.rot90(..., 3)``.
    """
    bbox = get_bbox(place, buffer)
    osm = gpd.GeoDataFrame(ox.features_from_bbox(bbox=bbox, tags=tags))
    clip_area = box(bbox.west, bbox.south, bbox.east, bbox.north)
    gdf = osm.clip(clip_area).to_crs("EPSG:3857")
    layers = raster_fn(gdf, shape=(buffer, buffer))
    basemap = jnp.rot90(basemap_fn(bbox, gdf), 3)

    def rotate(layer):
        return jnp.rot90(layer, 3)

    # Layer order follows `tags`: 0 building, 1 water, 2 highway, 3 landuse, 4 leisure.
    kernel = jnp.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]])
    # Clear water cells within one cell of a water/highway overlap
    # (presumably bridges).
    near_overlap = convolve(layers[1] * layers[2], kernel, mode="same")
    water = layers[1] - near_overlap > 0
    return tps.Terrain(
        building=rotate(layers[0]),
        water=rotate(water),
        forest=rotate(jnp.logical_or(layers[3], layers[4])),
        basemap=basemap,
    )
|
103
|
+
|
104
|
+
|
105
|
+
def raster_fn(gdf, shape) -> Array:
    """Rasterize each tag layer of ``gdf`` onto a grid of ``shape = (rows, cols)``.

    The grid spans ``gdf.total_bounds``. Returns an array of shape
    ``(len(tags), rows, cols)`` whose layers follow the key order of ``tags``.
    """
    west, south, east, north = gdf.total_bounds
    rows, cols = shape
    # from_bounds takes (width, height) = (cols, rows), whereas rasterize's
    # out_shape is (rows, cols); passing *shape to both only agreed for square grids.
    t = transform.from_bounds(west, south, east, north, cols, rows)  # type: ignore
    return jnp.array([feature_fn(t, feature, gdf, shape) for feature in tags])
|
110
|
+
|
111
|
+
|
112
|
+
def feature_fn(t, feature, gdf, shape):
    """Rasterize the rows of ``gdf`` that carry tag ``feature``.

    Returns a ``shape`` raster with ``rasterize``'s burn value on covered cells
    and 0 elsewhere. If no row carries the tag, returns all zeros.
    """
    if feature not in gdf.columns:
        return jnp.zeros(shape)
    gdf = gdf[~gdf[feature].isna()]
    if gdf.empty:
        # The column can exist while every tagged row was clipped away;
        # rasterio's rasterize rejects an empty geometry collection.
        return jnp.zeros(shape)
    return features.rasterize(gdf.geometry, out_shape=shape, transform=t, fill=0)  # type: ignore
|
118
|
+
|
119
|
+
|
120
|
+
# %%
|
121
|
+
if __name__ == "__main__":
    place = "Thun, Switzerland"
    terrain = geography_fn(place, 300)

    # One panel per obstacle layer, then the combined layers and the basemap.
    panels = [
        (terrain.building, "gray"),
        (terrain.water, "gray"),
        (terrain.forest, "gray"),
        (terrain.building + terrain.water + terrain.forest, None),
        (terrain.basemap, None),
    ]
    fig, axes = plt.subplots(1, 5, figsize=(20, 20))
    for panel_ax, (image, cmap) in zip(axes, panels):
        panel_ax.imshow(image, cmap=cmap)
|
parabellum/pcg.py
ADDED
@@ -0,0 +1,53 @@
|
|
1
|
+
# pcg.py
|
2
|
+
# procedural content generation
|
3
|
+
# by: Noah Syrkis
|
4
|
+
|
5
|
+
# %% Imports
|
6
|
+
from jax import random, vmap
|
7
|
+
import jax.numpy as jnp
|
8
|
+
import matplotlib.pyplot as plt
|
9
|
+
from functools import partial
|
10
|
+
|
11
|
+
|
12
|
+
# %% Functions
|
13
|
+
# n random lattice values in [0, 1) that the noise functions below sample.
seed, n = 0, 100
rng = random.PRNGKey(seed)
Y = random.uniform(rng, shape=(n,))
|
17
|
+
|
18
|
+
|
19
|
+
def g(t):
    """Cosine ease curve: 0 at t=0, 1 at t=1, smooth in between."""
    eased = 1 - jnp.cos(jnp.pi * t)
    return eased / 2
|
21
|
+
|
22
|
+
|
23
|
+
def lerp(a, b, t):
    """Linearly interpolate from ``a`` to ``b`` by the fractional part of ``t``.

    The fraction is computed out of place. The previous ``t -= jnp.floor(t)``
    modified a NumPy array argument in place.
    """
    frac = t - jnp.floor(t)
    return (1 - frac) * a + frac * b
|
26
|
+
|
27
|
+
|
28
|
+
def cerp(a, b, t):
    """Cosine-interpolate from ``a`` to ``b`` by the fractional part of ``t``.

    Uses the ease curve ``g``. The fraction is computed out of place so that
    a NumPy array ``t`` passed by the caller is not modified.
    """
    frac = t - jnp.floor(t)
    return g(1 - frac) * a + g(frac) * b
|
31
|
+
|
32
|
+
|
33
|
+
def body_fn(x):
    """Cosine-interpolate the lattice ``Y`` at position ``x``.

    ``floor(x)`` selects the left lattice point, and the fractional part of
    ``x`` sets the blend towards the next one. Indices are int32; the
    previous uint8 cast wrapped around for x >= 256 and for negative x.
    NOTE(review): at the last lattice point ``Y[i + 1]`` is out of range and
    JAX clamps the gather instead of raising. Confirm whether periodic
    wrap-around was intended.
    """
    i = jnp.floor(x).astype(jnp.int32)
    return cerp(Y[i], Y[i + 1], x)
|
36
|
+
|
37
|
+
|
38
|
+
@partial(vmap, in_axes=(None, 0, None))
def noise_fn(y, t, n):
    """Return ``y[t mod n]``; vectorised over ``t`` only."""
    wrapped = t % n
    return y[wrapped]
|
41
|
+
|
42
|
+
|
43
|
+
@vmap
def perlin_fn(t):
    """Sample ``Y`` at indices ``t * k (mod n)`` for ``k`` in ``range(3 * n)``.

    Vectorised over ``t``, so each row uses a different stride.
    """
    strided = t * jnp.arange(n * 3)
    return noise_fn(Y, strided, n)
|
46
|
+
|
47
|
+
|
48
|
+
xs = jnp.linspace(0, 1, 1000)  # NOTE: not used below
# Sum the strided lookups for strides 1, 2 and 4, then plot normalised.
noise = perlin_fn(2 ** jnp.arange(3)).sum(axis=0)
fig, ax = plt.subplots(figsize=(20, 4), dpi=100)
ax.set_ylim(0, 1)
ax.plot(noise / noise.max())
|
parabellum/run.py
CHANGED
@@ -21,7 +21,7 @@ bg = (0, 0, 0) if darkdetect.isDark() else (255, 255, 255)
|
|
21
21
|
|
22
22
|
|
23
23
|
# types
|
24
|
-
State = jaxmarl.environments.smax.smax_env.State
|
24
|
+
State = jaxmarl.environments.smax.smax_env.State # type: ignore
|
25
25
|
Obs = Reward = Done = Action = Dict[str, jnp.ndarray]
|
26
26
|
StateSeq = List[Tuple[jnp.ndarray, State, Action]]
|
27
27
|
|
parabellum/terrain_db.py
ADDED
@@ -0,0 +1,117 @@
|
|
1
|
+
# %%
|
2
|
+
import numpy as np
|
3
|
+
import jax.numpy as jnp
|
4
|
+
from parabellum import tps
|
5
|
+
|
6
|
+
|
7
|
+
# %%
|
8
|
+
def map_raster_from_line(raster, line, size):
    """Draw a straight segment into ``raster`` in place and return it.

    ``line`` is ``[x0, y0, dx, dy]`` in fractions of ``size``. The segment
    from (x0, y0) to (x0 + dx, y0 + dy) is sampled at ``int(sqrt(2) * size) + 1``
    evenly spaced points, and each sample is truncated to a cell. Samples
    outside the ``size`` x ``size`` grid are ignored.
    """
    x0, y0, dx, dy = (int(v * size) for v in line)
    steps = int(2**0.5 * size)
    for k in range(steps + 1):
        alpha = k / float(steps)
        x = int((1 - alpha) * x0 + alpha * (x0 + dx)) if dx != 0 else x0
        y = int((1 - alpha) * y0 + alpha * (y0 + dy)) if dy != 0 else y0
        inside = 0 <= x < size and 0 <= y < size
        if inside:
            raster[x, y] = 1
    return raster
|
22
|
+
|
23
|
+
|
24
|
+
# %%
|
25
|
+
def map_raster_from_rect(raster, rect, size):
    """Fill an axis-aligned rectangle of ``raster`` with ones, in place.

    Args:
        raster: 2-D array indexed ``raster[x, y]``; modified in place.
        rect: ``(x0, y0, dx, dy)`` as fractions of the map side; each value is
            scaled by ``size`` and truncated to an int. A negative ``dx``/``dy``
            extends the rectangle towards smaller coordinates.
        size: Side length of the raster in cells.

    Returns:
        The same ``raster`` object.
    """
    x0, y0, dx, dy = (int(v * size) for v in rect)
    # Order the corners and clamp them at 0. A raw negative slice bound would
    # wrap to the opposite edge (Python negative indexing), painting cells on
    # the far side of the map or nothing at all.
    x_lo, x_hi = (max(v, 0) for v in sorted((x0, x0 + dx)))
    y_lo, y_hi = (max(v, 0) for v in sorted((y0, y0 + dy)))
    raster[x_lo:x_hi, y_lo:y_hi] = 1
    return raster
|
33
|
+
|
34
|
+
|
35
|
+
# %%
|
36
|
+
# Four-channel colours painted into the basemap by make_terrain, one per
# layer plus the background (presumably RGBA given the trailing 255 — confirm
# against the renderer).
building_color, water_color, forest_color, empty_color = (
    jnp.array(channels)
    for channels in (
        [201, 199, 198, 255],
        [193, 237, 254, 255],
        [197, 214, 185, 255],
        [255, 255, 255, 255],
    )
)
|
40
|
+
|
41
|
+
def make_terrain(terrain_args, size):
    """Build a ``tps.Terrain`` from a declarative layer description.

    Args:
        terrain_args: Mapping from layer name to ``None``/``[]`` (empty layer)
            or a list of shape dicts, each holding a ``"line"`` or ``"rect"``
            entry in fractional ``[x0, y0, dx, dy]`` form. Shapes with neither
            key are ignored. The ``"building"``, ``"water"`` and ``"forest"``
            layers feed the basemap; any of them that is absent is treated as
            empty.
        size: Side length, in cells, of every generated raster.

    Returns:
        ``tps.Terrain`` holding one ``(size, size)`` mask per layer and a
        ``(size, size, 4)`` ``basemap``.
    """
    # Basemap layers default to empty, so a preset may omit any of them
    # (previously a missing layer raised KeyError when painting the basemap).
    layers = {"building": None, "water": None, "forest": None, **terrain_args}
    args = {}
    for key, config in layers.items():
        raster = np.zeros((size, size))
        for elem in config or ():
            if "line" in elem:
                raster = map_raster_from_line(raster, elem["line"], size)
            elif "rect" in elem:
                raster = map_raster_from_rect(raster, elem["rect"], size)
        # The helpers fill raster[x, y]; the stored mask is the transpose.
        args[key] = jnp.array(raster.T)
    # Later layers win: empty < building < water < forest. jnp.where broadcasts
    # each (4,) colour over the (size, size, 1) mask, giving the same
    # (size, size, 4) result as tiling the colours, without the full copies.
    basemap = jnp.where(args["building"][:, :, None], building_color, empty_color)
    basemap = jnp.where(args["water"][:, :, None], water_color, basemap)
    basemap = jnp.where(args["forest"][:, :, None], forest_color, basemap)
    args["basemap"] = basemap
    return tps.Terrain(**args)
|
57
|
+
|
58
|
+
|
59
|
+
# %%
|
60
|
+
# Named terrain presets consumed by make_terrain. Each shape is a "line" or a
# "rect" given as [x0, y0, dx, dy] in fractions of the map side; a layer set
# to None or [] contributes no shapes.
db = dict(
    blank={"building": None, "water": None, "forest": None},
    F={
        "building": [
            {"line": [0.25, 0.33, 0.5, 0]},
            {"line": [0.75, 0.33, 0.0, 0.25]},
            {"line": [0.50, 0.33, 0.0, 0.25]},
        ],
        "water": None,
        "forest": None,
    },
    stronghold={
        "building": [
            {"line": [0.2, 0.275, 0.2, 0.0]},
            {"line": [0.2, 0.275, 0.0, 0.2]},
            {"line": [0.4, 0.275, 0.0, 0.2]},
            {"line": [0.2, 0.475, 0.2, 0.0]},
            {"line": [0.2, 0.525, 0.2, 0.0]},
            {"line": [0.2, 0.525, 0.0, 0.2]},
            {"line": [0.4, 0.525, 0.0, 0.2]},
            {"line": [0.2, 0.725, 0.525, 0.0]},
            {"line": [0.75, 0.25, 0.0, 0.2]},
            {"line": [0.75, 0.55, 0.0, 0.19]},
            {"line": [0.6, 0.25, 0.15, 0.0]},
        ],
        "water": None,
        "forest": None,
    },
    playground={
        "building": [{"line": [0.5, 0.5, 0.5, 0.0]}],
        "water": None,
        "forest": None,
    },
    water_park={
        "building": [{"line": [0.5, 0.5, 0.5, 0.0]}],
        "water": [
            {"rect": [0.0, 0.8, 0.1, 0.05]},
            {"rect": [0.2, 0.8, 0.8, 0.05]},
        ],
        "forest": [{"rect": [0.0, 0.0, 1.0, 0.2]}],
    },
    triangle={
        "building": [
            {"line": [0.33, 0.0, 0.0, 1.0]},
            {"line": [0.66, 0.0, 0.0, 1.0]},
        ],
        "water": None,
        "forest": None,
    },
    u_shape={
        "building": [],
        "water": [
            {"rect": [0.15, 0.2, 0.1, 0.5]},
            {"rect": [0.4, 0.2, 0.1, 0.5]},
            {"rect": [0.2, 0.2, 0.25, 0.1]},
        ],
        "forest": [],
    },
)
|
85
|
+
|
86
|
+
# %% [raw]
|
87
|
+
# import matplotlib.pyplot as plt
|
88
|
+
# size = 50
|
89
|
+
# raster = np.zeros((size, size))
|
90
|
+
# rect = [0.2, 0.3, 0.05, 0.4]
|
91
|
+
# raster = map_raster_from_rect(raster, rect, size)
|
92
|
+
# rect = [0.4, 0.3, 0.05, 0.4]
|
93
|
+
# raster = map_raster_from_rect(raster, rect, size)
|
94
|
+
# rect = [0.2, 0.3, 0.25, 0.05]
|
95
|
+
# raster = map_raster_from_rect(raster, rect, size)
|
96
|
+
# rect = [0.2, 0.7, 0.25, 0.05]
|
97
|
+
# raster = map_raster_from_rect(raster, rect, size)
|
98
|
+
# rect = [0.6, 0.3, 0.4, 0.45]
|
99
|
+
# raster = map_raster_from_rect(raster, rect, size)
|
100
|
+
# plt.imshow(jnp.rot90(raster))
|
101
|
+
|
102
|
+
# %% [markdown]
|
103
|
+
# # Main
|
104
|
+
|
105
|
+
# %%
|
106
|
+
if __name__ == "__main__":
    # Manual preview, skipped on import: build the "u_shape" preset on a
    # 50x50 grid and show its basemap rotated by 90 degrees.
    import matplotlib.pyplot as plt

    # %%
    terrain = make_terrain(db["u_shape"], size=50)

    # %%
    plt.imshow(jnp.rot90(terrain.basemap))
|
114
|
+
|
115
|
+
# %%
|
116
|
+
|
117
|
+
# %%
|
parabellum/tps.py
ADDED
@@ -0,0 +1,17 @@
|
|
1
|
+
# %%
|
2
|
+
# tps.py
|
3
|
+
# parabellum types and dataclasses
|
4
|
+
# by: Noah Syrkis
|
5
|
+
|
6
|
+
# %% Imports
|
7
|
+
from chex import dataclass
|
8
|
+
from jaxtyping import Array
|
9
|
+
|
10
|
+
|
11
|
+
# %% Dataclasses
|
12
|
+
@dataclass
class Terrain:
    """Terrain layers for one map, as produced by ``terrain_db.make_terrain``.

    ``make_terrain`` fills ``building``, ``water`` and ``forest`` as per-cell
    masks and derives ``basemap`` from them by painting one colour per layer.
    """

    building: Array  # mask of building cells
    water: Array  # mask of water cells
    forest: Array  # mask of forest cells
    basemap: Array  # colour image painted from the three masks
|
parabellum/vis.py
CHANGED
@@ -3,26 +3,19 @@ Visualizer for the Parabellum environment
|
|
3
3
|
"""
|
4
4
|
|
5
5
|
# Standard library imports
|
6
|
-
from
|
7
|
-
from typing import Optional, List, Tuple
|
6
|
+
from typing import Optional, Tuple
|
8
7
|
import cv2
|
9
|
-
from PIL import Image
|
10
8
|
|
11
9
|
# JAX and JAX-related imports
|
12
10
|
import jax
|
13
11
|
from chex import dataclass
|
14
|
-
import
|
15
|
-
from jax import vmap, tree_util, Array, jit
|
12
|
+
from jax import vmap, Array
|
16
13
|
import jax.numpy as jnp
|
17
|
-
from jaxmarl.environments.multi_agent_env import MultiAgentEnv
|
18
|
-
from jaxmarl.environments.smax import SMAX
|
19
14
|
from jaxmarl.viz.visualizer import SMAXVisualizer
|
20
15
|
|
21
16
|
# Third-party imports
|
22
17
|
import numpy as np
|
23
18
|
import pygame
|
24
|
-
import cv2
|
25
|
-
from tqdm import tqdm
|
26
19
|
|
27
20
|
# Local imports
|
28
21
|
import parabellum as pb
|
@@ -35,12 +28,12 @@ class Skin:
|
|
35
28
|
maskmap: Array # maskmap of buildings
|
36
29
|
bg: Tuple[int, int, int] = (255, 255, 255)
|
37
30
|
fg: Tuple[int, int, int] = (0, 0, 0)
|
38
|
-
ally: Tuple[int, int, int]
|
39
|
-
enemy: Tuple[int, int, int]
|
40
|
-
pad: int
|
41
|
-
size: int
|
31
|
+
ally: Tuple[int, int, int] = (0, 255, 0)
|
32
|
+
enemy: Tuple[int, int, int] = (255, 0, 0)
|
33
|
+
pad: int = 100
|
34
|
+
size: int = 1000 # excluding padding
|
42
35
|
fps: int = 24
|
43
|
-
vis_size: int = 1000
|
36
|
+
vis_size: int = 1000 # size of the map in Vis (excluding padding)
|
44
37
|
scale: Optional[float] = None
|
45
38
|
|
46
39
|
|
@@ -57,10 +50,23 @@ class Visualizer(SMAXVisualizer):
|
|
57
50
|
self.env = env
|
58
51
|
|
59
52
|
def animate(self, save_fname: Optional[str] = "output/parabellum", view=None):
    """Write every recorded run to its own ``{save_fname}_{idx}.mp4`` video.

    The stored state/action sequences are expanded with ``expand_fn``, split
    per run with ``unbatch_fn`` and rendered with ``animate_fn``. ``view`` is
    accepted but not used here.
    """
    states, actions = expand_fn(self.env, self.state_seq, self.action_seq)
    per_run_states, per_run_actions = unbatch_fn(states, actions)
    runs = zip(per_run_states, per_run_actions)
    for idx, (state_seq, action_seq) in enumerate(runs):
        out_path = f"{save_fname}_{idx}.mp4"
        animate_fn(self.env, self.skin, self.image, state_seq, action_seq, out_path)
|
64
70
|
|
65
71
|
|
66
72
|
# functions
|
@@ -70,22 +76,29 @@ def animate_fn(env, skin, image, state_seq, action_seq, save_fname):
|
|
70
76
|
for idx, (state_tup, action) in enumerate(zip(state_seq, action_seq)):
|
71
77
|
frames += [frame_fn(env, skin, image, state_tup[1], action, idx)]
|
72
78
|
# use cv2 to write frames to video
|
73
|
-
fourcc = cv2.VideoWriter_fourcc(*
|
74
|
-
out = cv2.VideoWriter(
|
79
|
+
fourcc = cv2.VideoWriter_fourcc(*"mp4v") # type: ignore
|
80
|
+
out = cv2.VideoWriter(
|
81
|
+
save_fname,
|
82
|
+
fourcc,
|
83
|
+
skin.fps,
|
84
|
+
(skin.size + skin.pad * 2, skin.size + skin.pad * 2),
|
85
|
+
)
|
75
86
|
for frame in frames:
|
76
87
|
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
|
77
88
|
out.release()
|
78
89
|
pygame.quit()
|
79
90
|
|
80
91
|
|
81
|
-
def init_frame(
|
92
|
+
def init_frame(
    env, skin, image, state: pb.State, action: Array, idx: int
) -> pygame.Surface:
    """Create an empty square surface covering the map plus padding on each side.

    Only ``skin.size`` and ``skin.pad`` are read; the other parameters keep the
    signature in line with ``frame_fn``.
    """
    side = skin.size + 2 * skin.pad
    return pygame.Surface((side, side), pygame.SRCALPHA | pygame.HWSURFACE)
|
85
98
|
|
86
99
|
|
87
100
|
def transform_frame(env, skin, frame):
    """Turn a pygame surface into a row-indexed RGB array, flipped vertically.

    ``pixels3d`` indexes pixels as ``[x][y]``; swapping the first two axes gives
    ``[row][column]``, and the flip along axis 0 reverses the row order.
    ``env`` and ``skin`` are unused.
    """
    pixels = pygame.surfarray.pixels3d(frame)
    return np.flip(pixels.swapaxes(0, 1), 0)
|
91
104
|
|
@@ -102,7 +115,7 @@ def frame_fn(env, skin, image, state: pb.State, action: Array, idx: int) -> np.n
|
|
102
115
|
|
103
116
|
|
104
117
|
def render_background(env, skin, image, frame, state, action):
|
105
|
-
coords = (skin.pad-5, skin.pad-5, skin.size+10, skin.size+10)
|
118
|
+
coords = (skin.pad - 5, skin.pad - 5, skin.size + 10, skin.size + 10)
|
106
119
|
frame.fill(skin.bg)
|
107
120
|
frame.blit(image, coords)
|
108
121
|
pygame.draw.rect(frame, skin.fg, coords, 3)
|
@@ -116,6 +129,7 @@ def render_action(env, skin, image, frame, state, action):
|
|
116
129
|
def render_bullet(env, skin, image, frame, state, action):
    """Bullet rendering stub: hands ``frame`` back untouched."""
    unchanged = frame
    return unchanged
|
118
131
|
|
132
|
+
|
119
133
|
def render_agents(env, skin, image, frame, state, action):
|
120
134
|
units = state.unit_positions, state.unit_teams, state.unit_types, state.unit_health
|
121
135
|
for idx, (pos, team, kind, health) in enumerate(zip(*units)):
|
@@ -136,7 +150,11 @@ def text_fn(text):
|
|
136
150
|
|
137
151
|
def image_fn(skin: Skin): # TODO:
|
138
152
|
"""Create an image for background (basemap or maskmap)"""
|
139
|
-
motif = cv2.resize(
|
153
|
+
motif = cv2.resize(
|
154
|
+
np.array(skin.maskmap.T),
|
155
|
+
(skin.size, skin.size),
|
156
|
+
interpolation=cv2.INTER_NEAREST,
|
157
|
+
).astype(np.uint8)
|
140
158
|
motif = (motif > 0).astype(np.uint8)
|
141
159
|
image = np.zeros((skin.size, skin.size, 3), dtype=np.uint8) + skin.bg
|
142
160
|
image[motif == 1] = skin.fg
|
@@ -150,7 +168,9 @@ def unbatch_fn(state_seq, action_seq):
|
|
150
168
|
if is_multi_run(state_seq):
|
151
169
|
n_envs = state_seq[0][1].unit_positions.shape[0]
|
152
170
|
state_seq_seq = [jax.tree_map(lambda x: x[i], state_seq) for i in range(n_envs)]
|
153
|
-
action_seq_seq = [
|
171
|
+
action_seq_seq = [
|
172
|
+
jax.tree_map(lambda x: x[i], action_seq) for i in range(n_envs)
|
173
|
+
]
|
154
174
|
else:
|
155
175
|
state_seq_seq = [state_seq]
|
156
176
|
action_seq_seq = [action_seq]
|
@@ -161,7 +181,9 @@ def expand_fn(env, state_seq, action_seq):
|
|
161
181
|
"""Expand the state sequence"""
|
162
182
|
fn = env.expand_state_seq
|
163
183
|
state_seq = vmap(fn)(state_seq) if is_multi_run(state_seq) else fn(state_seq)
|
164
|
-
action_seq = [
|
184
|
+
action_seq = [
|
185
|
+
action_seq[i // env.world_steps_per_env_step] for i in range(len(state_seq))
|
186
|
+
]
|
165
187
|
return state_seq, action_seq
|
166
188
|
|
167
189
|
|
@@ -1,42 +1,44 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: parabellum
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.3.1
|
4
4
|
Summary: Parabellum environment for parallel warfare simulation
|
5
5
|
Home-page: https://github.com/syrkis/parabellum
|
6
6
|
License: MIT
|
7
7
|
Keywords: warfare,simulation,parallel,environment
|
8
8
|
Author: Noah Syrkis
|
9
9
|
Author-email: desk@syrkis.com
|
10
|
-
Requires-Python: >=3.11,<
|
10
|
+
Requires-Python: >=3.11,<3.12
|
11
11
|
Classifier: License :: OSI Approved :: MIT License
|
12
12
|
Classifier: Programming Language :: Python :: 3
|
13
13
|
Classifier: Programming Language :: Python :: 3.11
|
14
|
-
|
14
|
+
Requires-Dist: cartopy (>=0.23.0,<0.24.0)
|
15
15
|
Requires-Dist: contextily (>=1.6.0,<2.0.0)
|
16
16
|
Requires-Dist: darkdetect (>=0.8.0,<0.9.0)
|
17
|
+
Requires-Dist: einops (>=0.8.0,<0.9.0)
|
17
18
|
Requires-Dist: folium (>=0.17.0,<0.18.0)
|
18
|
-
Requires-Dist: geopandas (>=1.0.0,<2.0.0)
|
19
19
|
Requires-Dist: geopy (>=2.4.1,<3.0.0)
|
20
20
|
Requires-Dist: ipykernel (>=6.29.5,<7.0.0)
|
21
21
|
Requires-Dist: jax (==0.4.17)
|
22
22
|
Requires-Dist: jaxmarl (==0.0.3)
|
23
|
+
Requires-Dist: jaxtyping (>=0.2.33,<0.3.0)
|
23
24
|
Requires-Dist: jupyterlab (>=4.2.2,<5.0.0)
|
24
25
|
Requires-Dist: moviepy (>=1.0.3,<2.0.0)
|
25
26
|
Requires-Dist: numpy (<2)
|
26
27
|
Requires-Dist: opencv-python (>=4.10.0.84,<5.0.0.0)
|
27
|
-
Requires-Dist: osmnx (
|
28
|
+
Requires-Dist: osmnx (==2.0.0b0)
|
28
29
|
Requires-Dist: pandas (>=2.2.2,<3.0.0)
|
29
30
|
Requires-Dist: poetry (>=1.8.3,<2.0.0)
|
30
31
|
Requires-Dist: pygame (>=2.5.2,<3.0.0)
|
31
32
|
Requires-Dist: rasterio (>=1.3.10,<2.0.0)
|
32
33
|
Requires-Dist: seaborn (>=0.13.2,<0.14.0)
|
34
|
+
Requires-Dist: stadiamaps (>=3.2.1,<4.0.0)
|
33
35
|
Requires-Dist: tqdm (>=4.66.4,<5.0.0)
|
34
36
|
Project-URL: Repository, https://github.com/syrkis/parabellum
|
35
37
|
Description-Content-Type: text/markdown
|
36
38
|
|
37
39
|
# Parabellum
|
38
40
|
|
39
|
-
Ultra-scalable JaxMARL based warfare simulation engine
|
41
|
+
Ultra-scalable JaxMARL based warfare simulation engine.
|
40
42
|
|
41
43
|
[](https://parabellum.readthedocs.io/en/latest/?badge=latest)
|
42
44
|
|
@@ -93,3 +95,4 @@ Full documentation: [parabellum.readthedocs.io](https://parabellum.readthedocs.i
|
|
93
95
|
## License
|
94
96
|
|
95
97
|
MIT
|
98
|
+
|
@@ -0,0 +1,13 @@
|
|
1
|
+
parabellum/__init__.py,sha256=hIOLir7wgaf_HU4j8uos7PaCrofqPQcr3FcMlBsZyr8,406
|
2
|
+
parabellum/aid.py,sha256=BPabjN4BUq1HRhkwbc9pCNsXSF_ALiG8W8cHWTWeEH4,900
|
3
|
+
parabellum/env.py,sha256=0mDqQ7-OI-oufBMMBoUt72Kf5OvHr9thilLGzszlICY,22569
|
4
|
+
parabellum/geo.py,sha256=PwEwspOppTPrHIXDZB_nGPTnVFIvDzbh2WtqzVKMUaM,4198
|
5
|
+
parabellum/gun.py,sha256=nvsJdcZ2Qd6lbPlAgsUiaLhstTi1UdLQ8kOnbCenucY,2618
|
6
|
+
parabellum/pcg.py,sha256=d8KC_lbc4WUUUPaTdPJSx27VMGioys3jSGOWJ-2EahU,968
|
7
|
+
parabellum/run.py,sha256=Q53__AxzROZNgfZLVU5LDdcT61UMCkmQ_Q5wWUIrnqo,3473
|
8
|
+
parabellum/terrain_db.py,sha256=XTKlpLAi3ZwoVw4-KS-Eh15NKsBKP-yt8v6FJGUtwdM,3960
|
9
|
+
parabellum/tps.py,sha256=of-RBdelAbNCHQZd1I22RWmZkwUEh6f161mx0X_G2tE,257
|
10
|
+
parabellum/vis.py,sha256=ABHveJj0fLRWkxOv3LFIXK20QtdGhjskuFLsp7iTFu0,6185
|
11
|
+
parabellum-0.3.1.dist-info/METADATA,sha256=RrSY6CrhwpVlbdJzacX2iVkh_MEgtZkZZCPHJWJJjqo,2707
|
12
|
+
parabellum-0.3.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
13
|
+
parabellum-0.3.1.dist-info/RECORD,,
|
parabellum/map.py
DELETED
@@ -1,100 +0,0 @@
|
|
1
|
-
# map.py
|
2
|
-
# parabellum map functions
|
3
|
-
# by: Noah Syrkis
|
4
|
-
|
5
|
-
# imports
|
6
|
-
import jax.numpy as jnp
|
7
|
-
from geopy.geocoders import Nominatim
|
8
|
-
import geopandas as gpd
|
9
|
-
import osmnx as ox
|
10
|
-
import contextily as cx
|
11
|
-
import matplotlib.pyplot as plt
|
12
|
-
from rasterio import features
|
13
|
-
import rasterio.transform
|
14
|
-
from typing import Optional, Tuple
|
15
|
-
from geopy.location import Location
|
16
|
-
from shapely.geometry import Point
|
17
|
-
import os
|
18
|
-
import pickle
|
19
|
-
|
20
|
-
# constants
|
21
|
-
geolocator = Nominatim(user_agent="parabellum")
|
22
|
-
BUILDING_TAGS = {"building": True}
|
23
|
-
|
24
|
-
def get_location(place: str) -> Tuple[float, float]:
|
25
|
-
"""Get coordinates for a given place."""
|
26
|
-
coords: Optional[Location] = geolocator.geocode(place) # type: ignore
|
27
|
-
if coords is None:
|
28
|
-
raise ValueError(f"Could not geocode the place: {place}")
|
29
|
-
return (coords.latitude, coords.longitude)
|
30
|
-
|
31
|
-
def get_building_geometry(point: Tuple[float, float], size: int) -> gpd.GeoDataFrame:
|
32
|
-
"""Get building geometry for a given point and size."""
|
33
|
-
geometry = ox.features_from_point(point, tags=BUILDING_TAGS, dist=size // 2)
|
34
|
-
return gpd.GeoDataFrame(geometry).set_crs("EPSG:4326")
|
35
|
-
|
36
|
-
def rasterize_geometry(gdf: gpd.GeoDataFrame, size: int) -> jnp.ndarray:
|
37
|
-
"""Rasterize geometry and return as a JAX array."""
|
38
|
-
w, s, e, n = gdf.total_bounds
|
39
|
-
transform = rasterio.transform.from_bounds(w, s, e, n, size, size)
|
40
|
-
raster = features.rasterize(gdf.geometry, out_shape=(size, size), transform=transform)
|
41
|
-
return jnp.array(jnp.flip(raster, 0) ).astype(jnp.uint8)
|
42
|
-
|
43
|
-
# +
|
44
|
-
def get_from_cache(place, size):
|
45
|
-
if os.path.exists("./cache"):
|
46
|
-
name = str(hash((place, size))) + ".pk"
|
47
|
-
if os.path.exists("./cache/" + name):
|
48
|
-
with open("./cache/" + name, "rb") as f:
|
49
|
-
(mask, base) = pickle.load(f)
|
50
|
-
return (mask, base.astype(jnp.int64))
|
51
|
-
return (None, None)
|
52
|
-
|
53
|
-
def save_in_cache(place, size, mask, base):
|
54
|
-
if not os.path.exists("./cache"):
|
55
|
-
os.makedirs("./cache")
|
56
|
-
name = str(hash((place, size))) + ".pk"
|
57
|
-
with open("./cache/" + name, "wb") as f:
|
58
|
-
pickle.dump((mask, base), f)
|
59
|
-
|
60
|
-
def terrain_fn(place: str, size: int = 1000, with_cache: bool = True) -> Tuple[jnp.ndarray, jnp.ndarray]:
|
61
|
-
"""Returns a rasterized map of buildings for a given location."""
|
62
|
-
if with_cache:
|
63
|
-
mask, base = get_from_cache(place, size)
|
64
|
-
if not with_cache or mask is None:
|
65
|
-
point = get_location(place)
|
66
|
-
gdf = get_building_geometry(point, size)
|
67
|
-
mask = rasterize_geometry(gdf, size)
|
68
|
-
base = get_basemap(place, size)
|
69
|
-
if with_cache:
|
70
|
-
save_in_cache(place, size, mask, base)
|
71
|
-
return mask, base
|
72
|
-
|
73
|
-
|
74
|
-
# -
|
75
|
-
|
76
|
-
def get_basemap(place: str, size: int = 1000) -> jnp.ndarray:
|
77
|
-
"""Returns a basemap for a given place as a JAX array."""
|
78
|
-
point = get_location(place)
|
79
|
-
gdf = get_building_geometry(point, size)
|
80
|
-
basemap, _ = cx.bounds2img(*gdf.total_bounds, ll=True)
|
81
|
-
# get the middle size x size square
|
82
|
-
basemap = basemap[(basemap.shape[0] - size) // 2:(basemap.shape[0] + size) // 2,
|
83
|
-
(basemap.shape[1] - size) // 2:(basemap.shape[1] + size) // 2]
|
84
|
-
return basemap # jnp.array(jnp.rot90(basemap, 2)).astype(jnp.uint8)
|
85
|
-
|
86
|
-
|
87
|
-
if __name__ == "__main__":
|
88
|
-
place = "Cauvicourt, 14190, France"
|
89
|
-
mask, base = terrain_fn(place, 500)
|
90
|
-
|
91
|
-
fig, ax = plt.subplots(1, 3, figsize=(15, 5))
|
92
|
-
ax[0].imshow(jnp.flip(mask,0)) # type: ignore
|
93
|
-
ax[1].imshow(base) # type: ignore
|
94
|
-
ax[2].imshow(base) # type: ignore
|
95
|
-
ax[2].imshow(jnp.flip(mask,0), alpha=jnp.flip(mask,0)) # type: ignore
|
96
|
-
plt.show()
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
@@ -1,10 +0,0 @@
|
|
1
|
-
parabellum/__init__.py,sha256=-5cWXJkHnfH_CbhTEall8Wak8McAFXZHP1L8Fu7Uo5k,373
|
2
|
-
parabellum/aid.py,sha256=HWST27inTFXcp8b11izJF0U7N7DZnRTIS3n1Qfa-Ko4,106
|
3
|
-
parabellum/env.py,sha256=H1YGAtUYNJd8OHnZ3sOEXbag5L0WjtJHBGL8ymGPvoE,16898
|
4
|
-
parabellum/gun.py,sha256=nvsJdcZ2Qd6lbPlAgsUiaLhstTi1UdLQ8kOnbCenucY,2618
|
5
|
-
parabellum/map.py,sha256=UwMqwySasX5oLw9v5YMJARAPwvQThLTRW36NpbwvBC8,3564
|
6
|
-
parabellum/run.py,sha256=EO_F7VPwayatpSHrcbSahtinsV4QObhcx0jo-4KZO1E,3472
|
7
|
-
parabellum/vis.py,sha256=q7_OIMjzt-7nBOojVVW7Wiiq9ojsjaltIsH6eOOxPKk,6116
|
8
|
-
parabellum-0.2.26.dist-info/METADATA,sha256=AJwdmHRRPG2MosgufeWqQH7LsDRtTvFFGMh1azey9zA,2671
|
9
|
-
parabellum-0.2.26.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
10
|
-
parabellum-0.2.26.dist-info/RECORD,,
|
File without changes
|