parabellum 0.2.25__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- parabellum/__init__.py +9 -5
- parabellum/aid.py +22 -0
- parabellum/env.py +184 -95
- parabellum/geo.py +100 -0
- parabellum/map.py +87 -92
- parabellum/pcg.py +53 -0
- parabellum/run.py +1 -1
- parabellum/tps.py +16 -0
- parabellum/vis.py +48 -26
- {parabellum-0.2.25.dist-info → parabellum-0.3.0.dist-info}/METADATA +7 -5
- parabellum-0.3.0.dist-info/RECORD +13 -0
- parabellum-0.2.25.dist-info/RECORD +0 -10
- {parabellum-0.2.25.dist-info → parabellum-0.3.0.dist-info}/WHEEL +0 -0
parabellum/__init__.py
CHANGED
@@ -1,18 +1,22 @@
|
|
1
|
-
from .env import Environment, Scenario,
|
1
|
+
from .env import Environment, Scenario, make_scenario, State
|
2
2
|
from .vis import Visualizer, Skin
|
3
|
-
from .map import terrain_fn
|
4
3
|
from .gun import bullet_fn
|
5
|
-
|
4
|
+
from . import vis
|
5
|
+
from . import map
|
6
|
+
from . import env
|
7
|
+
from . import tps
|
6
8
|
# from .run import run
|
7
9
|
|
8
10
|
__all__ = [
|
11
|
+
"env",
|
12
|
+
"map",
|
13
|
+
"vis",
|
14
|
+
"tps",
|
9
15
|
"Environment",
|
10
16
|
"Scenario",
|
11
|
-
"scenarios",
|
12
17
|
"make_scenario",
|
13
18
|
"State",
|
14
19
|
"Visualizer",
|
15
20
|
"Skin",
|
16
|
-
"terrain_fn",
|
17
21
|
"bullet_fn",
|
18
22
|
]
|
parabellum/aid.py
CHANGED
@@ -3,3 +3,25 @@
|
|
3
3
|
# by: Noah Syrkis
|
4
4
|
|
5
5
|
# imports
|
6
|
+
import os
|
7
|
+
from collections import namedtuple
|
8
|
+
from typing import Tuple
|
9
|
+
import cartopy.crs as ccrs
|
10
|
+
|
11
|
+
# types
|
12
|
+
BBox = namedtuple("BBox", ["north", "south", "east", "west"]) # type: ignore
|
13
|
+
|
14
|
+
|
15
|
+
# coordinate function
|
16
|
+
def to_mercator(bbox: BBox) -> BBox:
|
17
|
+
proj = ccrs.Mercator()
|
18
|
+
west, south = proj.transform_point(bbox.west, bbox.south, ccrs.PlateCarree())
|
19
|
+
east, north = proj.transform_point(bbox.east, bbox.north, ccrs.PlateCarree())
|
20
|
+
return BBox(north=north, south=south, east=east, west=west)
|
21
|
+
|
22
|
+
|
23
|
+
def to_platecarree(bbox: BBox) -> BBox:
|
24
|
+
proj = ccrs.PlateCarree()
|
25
|
+
west, south = proj.transform_point(bbox.west, bbox.south, ccrs.Mercator())
|
26
|
+
east, north = proj.transform_point(bbox.east, bbox.north, ccrs.Mercator())
|
27
|
+
return BBox(north=north, south=south, east=east, west=west)
|
parabellum/env.py
CHANGED
@@ -2,15 +2,14 @@
|
|
2
2
|
|
3
3
|
import jax.numpy as jnp
|
4
4
|
import jax
|
5
|
-
import
|
6
|
-
from jax import random, Array
|
7
|
-
from jax import jit
|
5
|
+
from jax import random, Array, vmap, jit
|
8
6
|
from flax.struct import dataclass
|
9
7
|
import chex
|
10
|
-
from jax import vmap
|
11
8
|
from jaxmarl.environments.smax.smax_env import SMAX
|
9
|
+
|
12
10
|
from typing import Tuple, Dict, cast
|
13
11
|
from functools import partial
|
12
|
+
from parabellum import tps, geo
|
14
13
|
|
15
14
|
|
16
15
|
@dataclass
|
@@ -18,8 +17,8 @@ class Scenario:
|
|
18
17
|
"""Parabellum scenario"""
|
19
18
|
|
20
19
|
place: str
|
21
|
-
terrain_raster:
|
22
|
-
unit_starting_sectors: jnp.ndarray
|
20
|
+
terrain_raster: tps.Terrain
|
21
|
+
unit_starting_sectors: jnp.ndarray # must be of size (num_units, 4) where sectors[i] = (x, y, width, height) of the ith unit's spawning sector (in % of the real map)
|
23
22
|
unit_types: chex.Array
|
24
23
|
num_allies: int
|
25
24
|
num_enemies: int
|
@@ -27,9 +26,11 @@ class Scenario:
|
|
27
26
|
smacv2_position_generation: bool = False
|
28
27
|
smacv2_unit_type_generation: bool = False
|
29
28
|
|
29
|
+
|
30
30
|
@dataclass
|
31
31
|
class State:
|
32
|
-
|
32
|
+
# terrain: Array
|
33
|
+
unit_positions: Array # fsfds
|
33
34
|
unit_alive: Array
|
34
35
|
unit_teams: Array
|
35
36
|
unit_health: Array
|
@@ -41,71 +42,89 @@ class State:
|
|
41
42
|
terminal: bool
|
42
43
|
|
43
44
|
|
44
|
-
# default scenario
|
45
|
-
scenarios = {
|
46
|
-
"default": Scenario(
|
47
|
-
"Identity Town",
|
48
|
-
jnp.eye(64, dtype=jnp.uint8),
|
49
|
-
jnp.array([[0, 0, 0.2, 0.2], [0.7,0.7,0.2,0.2]]),
|
50
|
-
jnp.zeros((19,), dtype=jnp.uint8),
|
51
|
-
9,
|
52
|
-
10,
|
53
|
-
)
|
54
|
-
}
|
55
|
-
|
56
45
|
|
57
|
-
def make_scenario(
|
46
|
+
def make_scenario(
|
47
|
+
place,
|
48
|
+
size,
|
49
|
+
unit_starting_sectors,
|
50
|
+
allies_type,
|
51
|
+
n_allies,
|
52
|
+
enemies_type,
|
53
|
+
n_enemies,
|
54
|
+
):
|
55
|
+
terrain = geo.geography_fn(place, size)
|
56
|
+
if type(unit_starting_sectors) == list:
|
57
|
+
default_sector = [0, 0, size, size] # Noah feel confident that this is right. This means 50% chance. Sorry timothee if you end up here later. my bad bro.
|
58
|
+
correct_unit_starting_sectors = []
|
59
|
+
for i in range(n_allies+n_enemies):
|
60
|
+
selected_sector = None
|
61
|
+
for unit_ids, sector in unit_starting_sectors:
|
62
|
+
if i in unit_ids:
|
63
|
+
selected_sector = sector
|
64
|
+
if selected_sector is None:
|
65
|
+
selected_sector = default_sector
|
66
|
+
correct_unit_starting_sectors.append(selected_sector)
|
67
|
+
unit_starting_sectors = correct_unit_starting_sectors
|
58
68
|
if type(allies_type) == int:
|
59
69
|
allies = [allies_type] * n_allies
|
60
70
|
else:
|
61
|
-
assert
|
71
|
+
assert len(allies_type) == n_allies
|
62
72
|
allies = allies_type
|
63
|
-
|
73
|
+
|
64
74
|
if type(enemies_type) == int:
|
65
75
|
enemies = [enemies_type] * n_enemies
|
66
76
|
else:
|
67
|
-
assert
|
77
|
+
assert len(enemies_type) == n_enemies
|
68
78
|
enemies = enemies_type
|
69
79
|
unit_types = jnp.array(allies + enemies, dtype=jnp.uint8)
|
70
|
-
return Scenario(
|
80
|
+
return Scenario(
|
81
|
+
place, terrain, unit_starting_sectors, unit_types, n_allies, n_enemies # type: ignore
|
82
|
+
)
|
71
83
|
|
72
84
|
|
73
|
-
def spawn_fn(
|
85
|
+
def spawn_fn(rng: jnp.ndarray, units_spawning_sectors):
|
74
86
|
"""Spawns n agents on a map."""
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
return
|
83
|
-
|
84
|
-
|
85
|
-
def
|
86
|
-
"""
|
87
|
-
width, height
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
sector = terrain[coordy : coordy + int(sector[3] * height), coordx : coordx + int(sector[2] * width)] == 0
|
100
|
-
offset = jnp.array([coordx, coordy])
|
101
|
-
# sector is jnp.nonzero
|
102
|
-
return jnp.nonzero(sector.T), offset
|
87
|
+
spawn_positions = []
|
88
|
+
for sector in units_spawning_sectors:
|
89
|
+
rng, key_start, key_noise = random.split(rng, 3)
|
90
|
+
noise = random.uniform(key_noise, (2,)) * 0.5
|
91
|
+
idx = random.choice(key_start, sector[0].shape[0])
|
92
|
+
coord = jnp.array([sector[0][idx], sector[1][idx]])
|
93
|
+
spawn_positions.append(coord + noise)
|
94
|
+
return jnp.array(spawn_positions, dtype=jnp.float32)
|
95
|
+
|
96
|
+
|
97
|
+
def sectors_fn(sectors: jnp.ndarray, invalid_spawn_areas: jnp.ndarray):
|
98
|
+
"""
|
99
|
+
sectors must be of size (num_units, 4) where sectors[i] = (x, y, width, height) of the ith unit's spawning sector (in % of the real map)
|
100
|
+
"""
|
101
|
+
width, height = invalid_spawn_areas.shape
|
102
|
+
spawning_sectors = []
|
103
|
+
for sector in sectors:
|
104
|
+
coordx, coordy = jnp.array(sector[0] * width, dtype=jnp.int32), jnp.array(sector[1] * height, dtype=jnp.int32)
|
105
|
+
sector = (invalid_spawn_areas[coordy : coordy + int(sector[3] * height), coordx : coordx + int(sector[2] * width)] == 0)
|
106
|
+
valid = jnp.nonzero(sector.T)
|
107
|
+
if valid[0].shape[0] == 0:
|
108
|
+
raise ValueError(f"Sector {sector} only contains invalid spawn areas.")
|
109
|
+
spawning_sectors.append(jnp.array(valid) + jnp.array([coordx, coordy]).reshape((2, -1) ))
|
110
|
+
return spawning_sectors
|
103
111
|
|
104
112
|
|
105
113
|
class Environment(SMAX):
|
114
|
+
|
106
115
|
def __init__(self, scenario: Scenario, **kwargs):
|
107
|
-
map_height, map_width = scenario.terrain_raster.shape
|
116
|
+
map_height, map_width = scenario.terrain_raster.building.shape
|
108
117
|
args = dict(scenario=scenario, map_height=map_height, map_width=map_width)
|
118
|
+
if "unit_type_pushable" in kwargs:
|
119
|
+
self.unit_type_pushable = kwargs["unit_type_pushable"]
|
120
|
+
del kwargs["unit_type_pushable"]
|
121
|
+
else:
|
122
|
+
self.unit_type_pushable = jnp.array([1,1,0,0,0,1])
|
123
|
+
if "reset_when_done" in kwargs:
|
124
|
+
self.reset_when_done = kwargs["reset_when_done"]
|
125
|
+
del kwargs["reset_when_done"]
|
126
|
+
else:
|
127
|
+
self.reset_when_done = True
|
109
128
|
super(Environment, self).__init__(**args, walls_cause_death=False, **kwargs)
|
110
129
|
self.terrain_raster = scenario.terrain_raster
|
111
130
|
self.unit_starting_sectors = scenario.unit_starting_sectors
|
@@ -113,21 +132,18 @@ class Environment(SMAX):
|
|
113
132
|
# self.unit_type_health = jnp.array([100, 100, 100, 100], dtype=jnp.float32)
|
114
133
|
# self.unit_type_damage = jnp.array([10, 10, 10, 10], dtype=jnp.float32)
|
115
134
|
self.scenario = scenario
|
116
|
-
self.unit_type_velocities=jnp.array([3.15, 2.25, 4.13, 3.15, 4.13, 3.15])/2.5
|
135
|
+
self.unit_type_velocities = jnp.array([3.15, 2.25, 4.13, 3.15, 4.13, 3.15])/2.5 if "unit_type_velocities" not in kwargs else kwargs["unit_type_velocities"]
|
117
136
|
self.unit_type_attack_blasts = jnp.zeros((3,), dtype=jnp.float32) # TODO: add
|
118
137
|
self.max_steps = 200
|
119
|
-
self._push_units_away = lambda state, firmness
|
120
|
-
self.
|
121
|
-
self.
|
138
|
+
self._push_units_away = lambda state, firmness=1: state # overwrite push units
|
139
|
+
self.spawning_sectors = sectors_fn(self.unit_starting_sectors, scenario.terrain_raster.building + scenario.terrain_raster.water)
|
140
|
+
self.resolution = self.terrain_raster.building.shape[0] + self.terrain_raster.building.shape[1]
|
141
|
+
self.t = jnp.tile(jnp.linspace(0, 1, self.resolution), (2, self.resolution))
|
122
142
|
|
123
143
|
|
124
|
-
|
125
|
-
def reset(self, rng: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]:
|
144
|
+
def reset(self, rng: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]: # type: ignore
|
126
145
|
"""Environment-specific reset."""
|
127
|
-
|
128
|
-
team_0_start = spawn_fn(self.team0_sector, self.team0_sector_offset, self.num_allies, ally_key)
|
129
|
-
team_1_start = spawn_fn(self.team1_sector, self.team1_sector_offset, self.num_enemies, enemy_key)
|
130
|
-
unit_positions = jnp.concatenate([team_0_start, team_1_start])
|
146
|
+
unit_positions = spawn_fn(rng, self.spawning_sectors)
|
131
147
|
unit_teams = jnp.zeros((self.num_agents,))
|
132
148
|
unit_teams = unit_teams.at[self.num_allies :].set(1)
|
133
149
|
unit_weapon_cooldowns = jnp.zeros((self.num_agents,))
|
@@ -145,13 +161,66 @@ class Environment(SMAX):
|
|
145
161
|
time=0,
|
146
162
|
terminal=False,
|
147
163
|
unit_weapon_cooldowns=unit_weapon_cooldowns,
|
164
|
+
# terrain=self.terrain_raster,
|
148
165
|
)
|
149
|
-
state = self._push_units_away(state) # type: ignore
|
166
|
+
state = self._push_units_away(state) # type: ignore could be slow
|
150
167
|
obs = self.get_obs(state)
|
151
168
|
world_state = self.get_world_state(state)
|
152
|
-
obs["world_state"] = jax.lax.stop_gradient(world_state)
|
169
|
+
# obs["world_state"] = jax.lax.stop_gradient(world_state)
|
153
170
|
return obs, state
|
154
171
|
|
172
|
+
def step_env(self, rng, state: State, action: Array): # type: ignore
|
173
|
+
obs, state, rewards, dones, infos = super().step_env(rng, state, action)
|
174
|
+
# delete world_state from obs
|
175
|
+
obs.pop("world_state")
|
176
|
+
if not self.reset_when_done:
|
177
|
+
for key in dones.keys():
|
178
|
+
dones[key] = False
|
179
|
+
return obs, state, rewards, dones, infos
|
180
|
+
|
181
|
+
def get_obs_unit_list(self, state: State) -> Dict[str, chex.Array]: # type: ignore
|
182
|
+
"""Applies observation function to state."""
|
183
|
+
|
184
|
+
def get_features(i, j):
|
185
|
+
"""Get features of unit j as seen from unit i"""
|
186
|
+
# Can just keep them symmetrical for now.
|
187
|
+
# j here means 'the jth unit that is not i'
|
188
|
+
# The observation is such that allies are always first
|
189
|
+
# so for units in the second team we count in reverse.
|
190
|
+
j = jax.lax.cond(
|
191
|
+
i < self.num_allies,
|
192
|
+
lambda: j,
|
193
|
+
lambda: self.num_agents - j - 1,
|
194
|
+
)
|
195
|
+
offset = jax.lax.cond(i < self.num_allies, lambda: 1, lambda: -1)
|
196
|
+
j_idx = jax.lax.cond(
|
197
|
+
((j < i) & (i < self.num_allies)) | ((j > i) & (i >= self.num_allies)),
|
198
|
+
lambda: j,
|
199
|
+
lambda: j + offset,
|
200
|
+
)
|
201
|
+
empty_features = jnp.zeros(shape=(len(self.unit_features),))
|
202
|
+
features = self._observe_features(state, i, j_idx)
|
203
|
+
visible = (
|
204
|
+
jnp.linalg.norm(state.unit_positions[j_idx] - state.unit_positions[i])
|
205
|
+
< self.unit_type_sight_ranges[state.unit_types[i]]
|
206
|
+
)
|
207
|
+
return jax.lax.cond(
|
208
|
+
visible & state.unit_alive[i] & state.unit_alive[j_idx]
|
209
|
+
& self.has_line_of_sight(state.unit_positions[j_idx], state.unit_positions[i], self.terrain_raster.building + self.terrain_raster.forest),
|
210
|
+
lambda: features,
|
211
|
+
lambda: empty_features,
|
212
|
+
)
|
213
|
+
|
214
|
+
get_all_features_for_unit = jax.vmap(get_features, in_axes=(None, 0))
|
215
|
+
get_all_features = jax.vmap(get_all_features_for_unit, in_axes=(0, None))
|
216
|
+
other_unit_obs = get_all_features(
|
217
|
+
jnp.arange(self.num_agents), jnp.arange(self.num_agents - 1)
|
218
|
+
)
|
219
|
+
other_unit_obs = other_unit_obs.reshape((self.num_agents, -1))
|
220
|
+
get_all_self_features = jax.vmap(self._get_own_features, in_axes=(None, 0))
|
221
|
+
own_unit_obs = get_all_self_features(state, jnp.arange(self.num_agents))
|
222
|
+
obs = jnp.concatenate([other_unit_obs, own_unit_obs], axis=-1)
|
223
|
+
return {agent: obs[self.agent_ids[agent]] for agent in self.agents}
|
155
224
|
|
156
225
|
def _our_push_units_away(
|
157
226
|
self, pos, unit_types, firmness: float = 1.0
|
@@ -171,7 +240,19 @@ class Environment(SMAX):
|
|
171
240
|
pos
|
172
241
|
+ firmness * jnp.sum(delta_matrix * overlap_term[:, :, None], axis=1) / 2
|
173
242
|
)
|
174
|
-
return unit_positions
|
243
|
+
return jnp.where(self.unit_type_pushable[unit_types][:, None], unit_positions, pos)
|
244
|
+
|
245
|
+
def has_line_of_sight(self, source, target, raster_input): # this is tooooo slow TODO: make it fast
|
246
|
+
# we could compute this for units in sight only using a switch
|
247
|
+
|
248
|
+
cells = jnp.array(source[:, jnp.newaxis] * self.t + (1-self.t) * target[:, jnp.newaxis], dtype=jnp.int32)
|
249
|
+
|
250
|
+
mask = jnp.zeros(raster_input.shape).at[cells[1, :], cells[0, :]].set(1)
|
251
|
+
|
252
|
+
flag = ~jnp.any(jnp.logical_and(mask, raster_input))
|
253
|
+
|
254
|
+
return flag
|
255
|
+
|
175
256
|
|
176
257
|
@partial(jax.jit, static_argnums=(0,)) # replace the _world_step method
|
177
258
|
def _world_step( # modified version of JaxMARL's SMAX _world_step
|
@@ -180,25 +261,17 @@ class Environment(SMAX):
|
|
180
261
|
state: State,
|
181
262
|
actions: Tuple[chex.Array, chex.Array],
|
182
263
|
) -> State:
|
183
|
-
|
184
|
-
def intersect_fn(pos, new_pos, obs, obs_end):
|
185
|
-
d1 = jnp.cross(obs - pos, new_pos - pos)
|
186
|
-
d2 = jnp.cross(obs_end - pos, new_pos - pos)
|
187
|
-
d3 = jnp.cross(pos - obs, obs_end - obs)
|
188
|
-
d4 = jnp.cross(new_pos - obs, obs_end - obs)
|
189
|
-
return (d1 * d2 <= 0) & (d3 * d4 <= 0)
|
190
|
-
|
191
|
-
def raster_crossing(pos, new_pos):
|
264
|
+
def raster_crossing(pos, new_pos, mask: jnp.ndarray):
|
192
265
|
pos, new_pos = pos.astype(jnp.int32), new_pos.astype(jnp.int32)
|
193
|
-
raster = jnp.copy(self.terrain_raster)
|
194
266
|
minimum = jnp.minimum(pos, new_pos)
|
195
267
|
maximum = jnp.maximum(pos, new_pos)
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
return jnp.any(
|
201
|
-
|
268
|
+
mask = jnp.where(jnp.arange(mask.shape[0]) >= minimum[0], mask, 0)
|
269
|
+
mask = jnp.where(jnp.arange(mask.shape[0]) <= maximum[0], mask, 0)
|
270
|
+
mask = jnp.where(jnp.arange(mask.shape[1]) >= minimum[1], mask.T, 0).T
|
271
|
+
mask = jnp.where(jnp.arange(mask.shape[1]) <= maximum[1], mask.T, 0).T
|
272
|
+
return jnp.any(mask)
|
273
|
+
|
274
|
+
|
202
275
|
def update_position(idx, vec):
|
203
276
|
# Compute the movements slightly strangely.
|
204
277
|
# The velocities below are for diagonal directions
|
@@ -214,13 +287,13 @@ class Environment(SMAX):
|
|
214
287
|
)
|
215
288
|
# avoid going out of bounds
|
216
289
|
new_pos = jnp.maximum(
|
217
|
-
jnp.minimum(new_pos, jnp.array([self.map_width, self.map_height])),
|
290
|
+
jnp.minimum(new_pos, jnp.array([self.map_width-1, self.map_height-1])),
|
218
291
|
jnp.zeros((2,)),
|
219
292
|
)
|
220
293
|
|
221
294
|
#######################################################################
|
222
295
|
############################################ avoid going into obstacles
|
223
|
-
clash = raster_crossing(pos, new_pos)
|
296
|
+
clash = raster_crossing(pos, new_pos, self.terrain_raster.building + self.terrain_raster.water)
|
224
297
|
new_pos = jnp.where(clash, pos, new_pos)
|
225
298
|
|
226
299
|
#######################################################################
|
@@ -285,7 +358,7 @@ class Environment(SMAX):
|
|
285
358
|
bystander_idxs = bystander_fn(attacked_idx) # TODO: use
|
286
359
|
bystander_valid = (
|
287
360
|
jnp.where(attack_valid, bystander_idxs, jnp.zeros((self.num_agents,)))
|
288
|
-
.astype(jnp.bool_)
|
361
|
+
.astype(jnp.bool_) # type: ignore
|
289
362
|
.astype(jnp.float32)
|
290
363
|
)
|
291
364
|
bystander_health_diff = (
|
@@ -333,11 +406,14 @@ class Environment(SMAX):
|
|
333
406
|
|
334
407
|
# units push each other
|
335
408
|
new_pos = self._our_push_units_away(pos, state.unit_types)
|
336
|
-
clash = jax.vmap(raster_crossing)(pos, new_pos)
|
409
|
+
clash = jax.vmap(raster_crossing, in_axes=(0, 0, None))(pos, new_pos, self.terrain_raster.building + self.terrain_raster.water)
|
337
410
|
pos = jax.vmap(jnp.where)(clash, pos, new_pos)
|
338
411
|
# avoid going out of bounds
|
339
|
-
pos = jnp.maximum(
|
340
|
-
|
412
|
+
pos = jnp.maximum(
|
413
|
+
jnp.minimum(pos, jnp.array([self.map_width - 1, self.map_height - 1])), # type: ignore
|
414
|
+
jnp.zeros((2,)),
|
415
|
+
)
|
416
|
+
|
341
417
|
# Multiple enemies can attack the same unit.
|
342
418
|
# We have `(health_diff, attacked_idx)` pairs.
|
343
419
|
# `jax.lax.scatter_add` aggregates these exactly
|
@@ -380,26 +456,39 @@ class Environment(SMAX):
|
|
380
456
|
)
|
381
457
|
return state
|
382
458
|
|
383
|
-
|
384
459
|
if __name__ == "__main__":
|
385
460
|
n_envs = 4
|
386
461
|
|
387
|
-
|
462
|
+
|
463
|
+
n_allies = 10
|
464
|
+
scenario_kwargs = {"allies_type": 0, "n_allies": n_allies, "enemies_type": 0, "n_enemies": n_allies,
|
465
|
+
"place": "Vesterbro, Copenhagen, Denmark", "size": 256, "unit_starting_sectors":
|
466
|
+
[([i for i in range(n_allies)], [0.,0.45,0.1,0.1]), ([n_allies+i for i in range(n_allies)], [0.8,0.5,0.1,0.1])]}
|
467
|
+
scenario = make_scenario(**scenario_kwargs)
|
468
|
+
env = Environment(scenario)
|
388
469
|
rng, reset_rng = random.split(random.PRNGKey(0))
|
389
470
|
reset_key = random.split(reset_rng, n_envs)
|
390
471
|
obs, state = vmap(env.reset)(reset_key)
|
391
472
|
state_seq = []
|
392
473
|
|
393
|
-
print(state.unit_positions)
|
394
|
-
exit()
|
395
474
|
|
396
|
-
|
475
|
+
from tqdm import tqdm
|
476
|
+
import time
|
477
|
+
step = vmap(jit(env.step))
|
478
|
+
tic = time.time()
|
479
|
+
for i in tqdm(range(10)):
|
397
480
|
rng, act_rng, step_rng = random.split(rng, 3)
|
398
481
|
act_key = random.split(act_rng, (len(env.agents), n_envs))
|
482
|
+
print(tic - time.time())
|
399
483
|
act = {
|
400
484
|
a: vmap(env.action_space(a).sample)(act_key[i])
|
401
485
|
for i, a in enumerate(env.agents)
|
402
486
|
}
|
487
|
+
print(tic - time.time())
|
403
488
|
step_key = random.split(step_rng, n_envs)
|
489
|
+
print(tic - time.time())
|
404
490
|
state_seq.append((step_key, state, act))
|
405
|
-
|
491
|
+
print(tic - time.time())
|
492
|
+
obs, state, reward, done, infos = step(step_key, state, act)
|
493
|
+
print(tic - time.time())
|
494
|
+
tic = time.time()
|
parabellum/geo.py
ADDED
@@ -0,0 +1,100 @@
|
|
1
|
+
# %% geo.py
|
2
|
+
# script for geospatial level generation
|
3
|
+
# by: Noah Syrkis
|
4
|
+
|
5
|
+
# %% Imports
|
6
|
+
from parabellum import tps
|
7
|
+
import rasterio
|
8
|
+
from rasterio import features, transform
|
9
|
+
from geopy.geocoders import Nominatim
|
10
|
+
from geopy.distance import distance
|
11
|
+
import contextily as cx
|
12
|
+
import jax.numpy as jnp
|
13
|
+
import cartopy.crs as ccrs
|
14
|
+
from jaxtyping import Array
|
15
|
+
import numpy as np
|
16
|
+
from shapely import box
|
17
|
+
import osmnx as ox
|
18
|
+
import geopandas as gpd
|
19
|
+
from collections import namedtuple
|
20
|
+
from typing import Tuple
|
21
|
+
import matplotlib.pyplot as plt
|
22
|
+
import seaborn as sns
|
23
|
+
import os
|
24
|
+
|
25
|
+
# %% Types
|
26
|
+
Coords = Tuple[float, float]
|
27
|
+
BBox = namedtuple("BBox", ["north", "south", "east", "west"]) # type: ignore
|
28
|
+
|
29
|
+
# %% Constants
|
30
|
+
provider = cx.providers.Stadia.StamenTerrain( # type: ignore
|
31
|
+
api_key="86d0d32b-d2fe-49af-8db8-f7751f58e83f"
|
32
|
+
)
|
33
|
+
provider["url"] = provider["url"] + "?api_key={api_key}"
|
34
|
+
tags = {"building": True, "water": True, "landuse": "forest"} # "road": True}
|
35
|
+
|
36
|
+
|
37
|
+
# %% Coordinate function
|
38
|
+
def get_coordinates(place: str) -> Coords:
|
39
|
+
geolocator = Nominatim(user_agent="parabellum")
|
40
|
+
point = geolocator.geocode(place)
|
41
|
+
return point.latitude, point.longitude # type: ignore
|
42
|
+
|
43
|
+
|
44
|
+
def get_bbox(place: str, buffer) -> BBox:
|
45
|
+
"""Get bounding box from place name in crs 4326."""
|
46
|
+
coords = get_coordinates(place)
|
47
|
+
north = distance(meters=buffer).destination(coords, bearing=0).latitude
|
48
|
+
south = distance(meters=buffer).destination(coords, bearing=180).latitude
|
49
|
+
east = distance(meters=buffer).destination(coords, bearing=90).longitude
|
50
|
+
west = distance(meters=buffer).destination(coords, bearing=270).longitude
|
51
|
+
return BBox(north, south, east, west)
|
52
|
+
|
53
|
+
|
54
|
+
def basemap_fn(bbox: BBox, gdf) -> Array:
|
55
|
+
fig, ax = plt.subplots(figsize=(20, 20), subplot_kw={"projection": ccrs.Mercator()})
|
56
|
+
gdf.plot(ax=ax, color="black", alpha=0, edgecolor="black") # type: ignore
|
57
|
+
cx.add_basemap(ax, crs=gdf.crs, source=provider, zoom="auto") # type: ignore
|
58
|
+
bbox = gdf.total_bounds
|
59
|
+
ax.set_extent([bbox[0], bbox[2], bbox[1], bbox[3]], crs=ccrs.Mercator()) # type: ignore
|
60
|
+
plt.axis("off")
|
61
|
+
plt.tight_layout()
|
62
|
+
fig.canvas.draw()
|
63
|
+
image = jnp.array(fig.canvas.renderer._renderer) # type: ignore
|
64
|
+
plt.close(fig)
|
65
|
+
return image
|
66
|
+
|
67
|
+
|
68
|
+
def geography_fn(place, buffer):
|
69
|
+
bbox = get_bbox(place, buffer)
|
70
|
+
map_data = ox.features_from_bbox(bbox=bbox, tags=tags)
|
71
|
+
gdf = gpd.GeoDataFrame(map_data)
|
72
|
+
gdf = gdf.clip(box(bbox.west, bbox.south, bbox.east, bbox.north)).to_crs("EPSG:3857")
|
73
|
+
raster = raster_fn(gdf, shape=(buffer, buffer))
|
74
|
+
basemap = basemap_fn(bbox, gdf)
|
75
|
+
terrain = tps.Terrain(building=raster[0], water=raster[1], forest=raster[2], basemap=basemap)
|
76
|
+
return terrain
|
77
|
+
|
78
|
+
|
79
|
+
def raster_fn(gdf, shape) -> Array:
|
80
|
+
bbox = gdf.total_bounds
|
81
|
+
t = transform.from_bounds(*bbox, *shape) # type: ignore
|
82
|
+
raster = jnp.array([feature_fn(t, feature, gdf, shape) for feature in ["building", "water", "landuse"]])
|
83
|
+
return raster
|
84
|
+
|
85
|
+
def feature_fn(t, feature, gdf, shape):
|
86
|
+
if feature not in gdf.columns:
|
87
|
+
return jnp.zeros(shape)
|
88
|
+
gdf = gdf[~gdf[feature].isna()]
|
89
|
+
raster = features.rasterize(gdf.geometry, out_shape=shape, transform=t, fill=0) # type: ignore
|
90
|
+
return raster
|
91
|
+
|
92
|
+
place = "Thun, Switzerland"
|
93
|
+
terrain = geography_fn(place, 800)
|
94
|
+
# %%
|
95
|
+
fig, axes = plt.subplots(1, 5, figsize=(20, 20))
|
96
|
+
axes[0].imshow(terrain.building, cmap="gray")
|
97
|
+
axes[1].imshow(terrain.water, cmap="gray")
|
98
|
+
axes[2].imshow(terrain.forest, cmap="gray")
|
99
|
+
axes[3].imshow(terrain.building + terrain.water + terrain.forest)
|
100
|
+
axes[4].imshow(terrain.basemap)
|
parabellum/map.py
CHANGED
@@ -1,100 +1,95 @@
|
|
1
|
-
#
|
2
|
-
|
1
|
+
# ludens.py
|
2
|
+
# script for fucking around and finding out
|
3
3
|
# by: Noah Syrkis
|
4
4
|
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
import
|
5
|
+
|
6
|
+
# %% Imports
|
7
|
+
# import parabellum as pb
|
8
|
+
import matplotlib.pyplot as plt
|
9
9
|
import osmnx as ox
|
10
|
+
from geopy.geocoders import Nominatim
|
11
|
+
import numpy as np
|
10
12
|
import contextily as cx
|
11
|
-
import
|
13
|
+
import jax.numpy as jnp
|
14
|
+
import geopandas as gpd
|
15
|
+
import rasterio
|
12
16
|
from rasterio import features
|
13
|
-
import rasterio.transform
|
14
|
-
from typing import Optional, Tuple
|
15
|
-
from geopy.location import Location
|
16
17
|
from shapely.geometry import Point
|
17
|
-
import
|
18
|
-
import pickle
|
18
|
+
from typing import List
|
19
19
|
|
20
|
-
#
|
20
|
+
# %% Constants
|
21
21
|
geolocator = Nominatim(user_agent="parabellum")
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
plt.show()
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
22
|
+
source = cx.providers.OpenStreetMap.Mapnik # type: ignore
|
23
|
+
|
24
|
+
|
25
|
+
def get_raster(
|
26
|
+
place: str, meters: int = 1000, tags: List[dict] | dict = {"building": True}
|
27
|
+
) -> jnp.ndarray:
|
28
|
+
# look here for tags https://wiki.openstreetmap.org/wiki/Map_features
|
29
|
+
def aux(place, tag):
|
30
|
+
"""Rasterize geometry and return as a JAX array."""
|
31
|
+
place = geolocator.geocode(place) # type: ignore
|
32
|
+
point = place.latitude, place.longitude # type: ignore # confusing order of lat/lon
|
33
|
+
geom = ox.features_from_point(point, tags=tag, dist=meters // 2)
|
34
|
+
gdf = gpd.GeoDataFrame(geom).set_crs("EPSG:4326")
|
35
|
+
# crop everythin outside of the meters x meters square
|
36
|
+
gdf = gdf.cx[
|
37
|
+
place.longitude - meters / 2 : place.longitude + meters / 2,
|
38
|
+
place.latitude - meters / 2 : place.latitude + meters / 2,
|
39
|
+
]
|
40
|
+
|
41
|
+
# bounds should be meters, meters
|
42
|
+
t = rasterio.transform.from_bounds(*bounds, meters, meters) # type: ignore
|
43
|
+
raster = features.rasterize(
|
44
|
+
gdf.geometry, out_shape=(meters, meters), transform=t
|
45
|
+
)
|
46
|
+
return jnp.array(raster)
|
47
|
+
|
48
|
+
if isinstance(tags, dict):
|
49
|
+
return aux(place, tags)
|
50
|
+
else:
|
51
|
+
return jnp.stack([aux(place, tag) for tag in tags])
|
52
|
+
|
53
|
+
|
54
|
+
def get_basemap(
|
55
|
+
place: str, size: int = 1000
|
56
|
+
) -> np.ndarray: # TODO: image is slightly off from raster. Fix this.
|
57
|
+
# Create a GeoDataFrame with the center point
|
58
|
+
place = geolocator.geocode(place) # type: ignore
|
59
|
+
lon, lat = place.longitude, place.latitude # type: ignore
|
60
|
+
gdf = gpd.GeoDataFrame(geometry=[Point(lon, lat)], crs="EPSG:4326")
|
61
|
+
gdf = gdf.to_crs("EPSG:3857")
|
62
|
+
|
63
|
+
# Create a buffer around the center point
|
64
|
+
# buffer = gdf.buffer(size) # type: ignore
|
65
|
+
buffer = gdf
|
66
|
+
bounds = buffer.total_bounds # i think this is wrong, since it ignores empty space
|
67
|
+
# modify bounds to include empty space
|
68
|
+
bounds = (bounds[0] - size, bounds[1] - size, bounds[2] + size, bounds[3] + size)
|
69
|
+
|
70
|
+
# Create a figure and axis
|
71
|
+
dpi = 300
|
72
|
+
fig, ax = plt.subplots(figsize=(size / dpi, size / dpi), dpi=dpi)
|
73
|
+
buffer.plot(ax=ax, facecolor="none", edgecolor="red", linewidth=0)
|
74
|
+
|
75
|
+
# Calculate the zoom level for the basemap
|
76
|
+
|
77
|
+
# Add the basemap to the axis
|
78
|
+
cx.add_basemap(ax, source=source, zoom="auto", attribution=False)
|
79
|
+
|
80
|
+
# Set the x and y limits of the axis
|
81
|
+
ax.set_xlim(bounds[0], bounds[2])
|
82
|
+
ax.set_ylim(bounds[1], bounds[3])
|
83
|
+
|
84
|
+
# convert the image (without axis or border) to a numpy array
|
85
|
+
plt.axis("off")
|
86
|
+
plt.tight_layout()
|
87
|
+
|
88
|
+
# remove whitespace
|
89
|
+
fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
|
90
|
+
fig.canvas.draw()
|
91
|
+
|
92
|
+
image = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) # type: ignore
|
93
|
+
image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
94
|
+
plt.close()
|
95
|
+
return jnp.array(image) # type: ignore
|
parabellum/pcg.py
ADDED
@@ -0,0 +1,53 @@
|
|
1
|
+
# pcg.py
|
2
|
+
# procedural content generation
|
3
|
+
# by: Noah Syrkis
|
4
|
+
|
5
|
+
# %% Imports
|
6
|
+
from jax import random, vmap
|
7
|
+
import jax.numpy as jnp
|
8
|
+
import matplotlib.pyplot as plt
|
9
|
+
from functools import partial
|
10
|
+
|
11
|
+
|
12
|
+
# %% Functions
|
13
|
+
seed = 0
|
14
|
+
n = 100
|
15
|
+
rng = random.PRNGKey(seed)
|
16
|
+
Y = random.uniform(rng, (n,))
|
17
|
+
|
18
|
+
|
19
|
+
def g(t):
|
20
|
+
return (1 - jnp.cos(jnp.pi * t)) / 2
|
21
|
+
|
22
|
+
|
23
|
+
def lerp(a, b, t):
|
24
|
+
t -= jnp.floor(t) # the fractional part of t
|
25
|
+
return (1 - t) * a + t * b
|
26
|
+
|
27
|
+
|
28
|
+
def cerp(a, b, t):
|
29
|
+
t -= jnp.floor(t) # the fractional part of t
|
30
|
+
return g(1 - t) * a + g(t) * b
|
31
|
+
|
32
|
+
|
33
|
+
def body_fn(x):
|
34
|
+
i = jnp.floor(x).astype(jnp.uint8)
|
35
|
+
return cerp(Y[i], Y[i + 1], x)
|
36
|
+
|
37
|
+
|
38
|
+
@partial(vmap, in_axes=(None, 0, None))
|
39
|
+
def noise_fn(y, t, n):
|
40
|
+
return y[t % n]
|
41
|
+
|
42
|
+
|
43
|
+
@vmap
|
44
|
+
def perlin_fn(t):
|
45
|
+
return noise_fn(Y, t * jnp.arange(n * 3), n)
|
46
|
+
|
47
|
+
|
48
|
+
xs = jnp.linspace(0, 1, 1000)
|
49
|
+
noise = perlin_fn(2 ** jnp.arange(3)).sum(0)
|
50
|
+
|
51
|
+
fig, ax = plt.subplots(figsize=(20, 4), dpi=100)
|
52
|
+
ax.set_ylim(0, 1)
|
53
|
+
ax.plot(noise / noise.max())
|
parabellum/run.py
CHANGED
@@ -21,7 +21,7 @@ bg = (0, 0, 0) if darkdetect.isDark() else (255, 255, 255)
|
|
21
21
|
|
22
22
|
|
23
23
|
# types
|
24
|
-
State = jaxmarl.environments.smax.smax_env.State
|
24
|
+
State = jaxmarl.environments.smax.smax_env.State # type: ignore
|
25
25
|
Obs = Reward = Done = Action = Dict[str, jnp.ndarray]
|
26
26
|
StateSeq = List[Tuple[jnp.ndarray, State, Action]]
|
27
27
|
|
parabellum/tps.py
ADDED
@@ -0,0 +1,16 @@
|
|
1
|
+
# tps.py
|
2
|
+
# parabellum types and dataclasses
|
3
|
+
# by: Noah Syrkis
|
4
|
+
|
5
|
+
# %% Imports
|
6
|
+
from chex import dataclass
|
7
|
+
from jaxtyping import Array
|
8
|
+
|
9
|
+
|
10
|
+
# %% Dataclasses
|
11
|
+
@dataclass
|
12
|
+
class Terrain:
|
13
|
+
building: Array
|
14
|
+
water: Array
|
15
|
+
forest: Array
|
16
|
+
basemap: Array
|
parabellum/vis.py
CHANGED
@@ -3,26 +3,19 @@ Visualizer for the Parabellum environment
|
|
3
3
|
"""
|
4
4
|
|
5
5
|
# Standard library imports
|
6
|
-
from
|
7
|
-
from typing import Optional, List, Tuple
|
6
|
+
from typing import Optional, Tuple
|
8
7
|
import cv2
|
9
|
-
from PIL import Image
|
10
8
|
|
11
9
|
# JAX and JAX-related imports
|
12
10
|
import jax
|
13
11
|
from chex import dataclass
|
14
|
-
import
|
15
|
-
from jax import vmap, tree_util, Array, jit
|
12
|
+
from jax import vmap, Array
|
16
13
|
import jax.numpy as jnp
|
17
|
-
from jaxmarl.environments.multi_agent_env import MultiAgentEnv
|
18
|
-
from jaxmarl.environments.smax import SMAX
|
19
14
|
from jaxmarl.viz.visualizer import SMAXVisualizer
|
20
15
|
|
21
16
|
# Third-party imports
|
22
17
|
import numpy as np
|
23
18
|
import pygame
|
24
|
-
import cv2
|
25
|
-
from tqdm import tqdm
|
26
19
|
|
27
20
|
# Local imports
|
28
21
|
import parabellum as pb
|
@@ -35,12 +28,12 @@ class Skin:
|
|
35
28
|
maskmap: Array # maskmap of buildings
|
36
29
|
bg: Tuple[int, int, int] = (255, 255, 255)
|
37
30
|
fg: Tuple[int, int, int] = (0, 0, 0)
|
38
|
-
ally: Tuple[int, int, int]
|
39
|
-
enemy: Tuple[int, int, int]
|
40
|
-
pad: int
|
41
|
-
size: int
|
31
|
+
ally: Tuple[int, int, int] = (0, 255, 0)
|
32
|
+
enemy: Tuple[int, int, int] = (255, 0, 0)
|
33
|
+
pad: int = 100
|
34
|
+
size: int = 1000 # excluding padding
|
42
35
|
fps: int = 24
|
43
|
-
vis_size: int = 1000
|
36
|
+
vis_size: int = 1000 # size of the map in Vis (excluding padding)
|
44
37
|
scale: Optional[float] = None
|
45
38
|
|
46
39
|
|
@@ -57,10 +50,23 @@ class Visualizer(SMAXVisualizer):
|
|
57
50
|
self.env = env
|
58
51
|
|
59
52
|
def animate(self, save_fname: Optional[str] = "output/parabellum", view=None):
|
60
|
-
expanded_state_seq, expanded_action_seq = expand_fn(
|
61
|
-
|
62
|
-
|
63
|
-
|
53
|
+
expanded_state_seq, expanded_action_seq = expand_fn(
|
54
|
+
self.env, self.state_seq, self.action_seq
|
55
|
+
)
|
56
|
+
state_seq_seq, action_seq_seq = unbatch_fn(
|
57
|
+
expanded_state_seq, expanded_action_seq
|
58
|
+
)
|
59
|
+
for idx, (state_seq, action_seq) in enumerate(
|
60
|
+
zip(state_seq_seq, action_seq_seq)
|
61
|
+
):
|
62
|
+
animate_fn(
|
63
|
+
self.env,
|
64
|
+
self.skin,
|
65
|
+
self.image,
|
66
|
+
state_seq,
|
67
|
+
action_seq,
|
68
|
+
f"{save_fname}_{idx}.mp4",
|
69
|
+
)
|
64
70
|
|
65
71
|
|
66
72
|
# functions
|
@@ -70,22 +76,29 @@ def animate_fn(env, skin, image, state_seq, action_seq, save_fname):
|
|
70
76
|
for idx, (state_tup, action) in enumerate(zip(state_seq, action_seq)):
|
71
77
|
frames += [frame_fn(env, skin, image, state_tup[1], action, idx)]
|
72
78
|
# use cv2 to write frames to video
|
73
|
-
fourcc = cv2.VideoWriter_fourcc(*
|
74
|
-
out = cv2.VideoWriter(
|
79
|
+
fourcc = cv2.VideoWriter_fourcc(*"mp4v") # type: ignore
|
80
|
+
out = cv2.VideoWriter(
|
81
|
+
save_fname,
|
82
|
+
fourcc,
|
83
|
+
skin.fps,
|
84
|
+
(skin.size + skin.pad * 2, skin.size + skin.pad * 2),
|
85
|
+
)
|
75
86
|
for frame in frames:
|
76
87
|
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
|
77
88
|
out.release()
|
78
89
|
pygame.quit()
|
79
90
|
|
80
91
|
|
81
|
-
def init_frame(
|
92
|
+
def init_frame(
|
93
|
+
env, skin, image, state: pb.State, action: Array, idx: int
|
94
|
+
) -> pygame.Surface:
|
82
95
|
dims = (skin.size + skin.pad * 2, skin.size + skin.pad * 2)
|
83
96
|
frame = pygame.Surface(dims, pygame.SRCALPHA | pygame.HWSURFACE)
|
84
97
|
return frame
|
85
98
|
|
86
99
|
|
87
100
|
def transform_frame(env, skin, frame):
|
88
|
-
#frame = np.rot90(pygame.surfarray.pixels3d(frame).swapaxes(0, 1), 2)
|
101
|
+
# frame = np.rot90(pygame.surfarray.pixels3d(frame).swapaxes(0, 1), 2)
|
89
102
|
frame = np.flip(pygame.surfarray.pixels3d(frame).swapaxes(0, 1), 0)
|
90
103
|
return frame
|
91
104
|
|
@@ -102,7 +115,7 @@ def frame_fn(env, skin, image, state: pb.State, action: Array, idx: int) -> np.n
|
|
102
115
|
|
103
116
|
|
104
117
|
def render_background(env, skin, image, frame, state, action):
|
105
|
-
coords = (skin.pad-5, skin.pad-5, skin.size+10, skin.size+10)
|
118
|
+
coords = (skin.pad - 5, skin.pad - 5, skin.size + 10, skin.size + 10)
|
106
119
|
frame.fill(skin.bg)
|
107
120
|
frame.blit(image, coords)
|
108
121
|
pygame.draw.rect(frame, skin.fg, coords, 3)
|
@@ -116,6 +129,7 @@ def render_action(env, skin, image, frame, state, action):
|
|
116
129
|
def render_bullet(env, skin, image, frame, state, action):
|
117
130
|
return frame
|
118
131
|
|
132
|
+
|
119
133
|
def render_agents(env, skin, image, frame, state, action):
|
120
134
|
units = state.unit_positions, state.unit_teams, state.unit_types, state.unit_health
|
121
135
|
for idx, (pos, team, kind, health) in enumerate(zip(*units)):
|
@@ -136,7 +150,11 @@ def text_fn(text):
|
|
136
150
|
|
137
151
|
def image_fn(skin: Skin): # TODO:
|
138
152
|
"""Create an image for background (basemap or maskmap)"""
|
139
|
-
motif = cv2.resize(
|
153
|
+
motif = cv2.resize(
|
154
|
+
np.array(skin.maskmap.T),
|
155
|
+
(skin.size, skin.size),
|
156
|
+
interpolation=cv2.INTER_NEAREST,
|
157
|
+
).astype(np.uint8)
|
140
158
|
motif = (motif > 0).astype(np.uint8)
|
141
159
|
image = np.zeros((skin.size, skin.size, 3), dtype=np.uint8) + skin.bg
|
142
160
|
image[motif == 1] = skin.fg
|
@@ -150,7 +168,9 @@ def unbatch_fn(state_seq, action_seq):
|
|
150
168
|
if is_multi_run(state_seq):
|
151
169
|
n_envs = state_seq[0][1].unit_positions.shape[0]
|
152
170
|
state_seq_seq = [jax.tree_map(lambda x: x[i], state_seq) for i in range(n_envs)]
|
153
|
-
action_seq_seq = [
|
171
|
+
action_seq_seq = [
|
172
|
+
jax.tree_map(lambda x: x[i], action_seq) for i in range(n_envs)
|
173
|
+
]
|
154
174
|
else:
|
155
175
|
state_seq_seq = [state_seq]
|
156
176
|
action_seq_seq = [action_seq]
|
@@ -161,7 +181,9 @@ def expand_fn(env, state_seq, action_seq):
|
|
161
181
|
"""Expand the state sequence"""
|
162
182
|
fn = env.expand_state_seq
|
163
183
|
state_seq = vmap(fn)(state_seq) if is_multi_run(state_seq) else fn(state_seq)
|
164
|
-
action_seq = [
|
184
|
+
action_seq = [
|
185
|
+
action_seq[i // env.world_steps_per_env_step] for i in range(len(state_seq))
|
186
|
+
]
|
165
187
|
return state_seq, action_seq
|
166
188
|
|
167
189
|
|
@@ -1,35 +1,37 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: parabellum
|
3
|
-
Version: 0.2.25
|
3
|
+
Version: 0.3.0
|
4
4
|
Summary: Parabellum environment for parallel warfare simulation
|
5
5
|
Home-page: https://github.com/syrkis/parabellum
|
6
6
|
License: MIT
|
7
7
|
Keywords: warfare,simulation,parallel,environment
|
8
8
|
Author: Noah Syrkis
|
9
9
|
Author-email: desk@syrkis.com
|
10
|
-
Requires-Python: >=3.11,<
|
10
|
+
Requires-Python: >=3.11,<3.12
|
11
11
|
Classifier: License :: OSI Approved :: MIT License
|
12
12
|
Classifier: Programming Language :: Python :: 3
|
13
13
|
Classifier: Programming Language :: Python :: 3.11
|
14
|
-
|
14
|
+
Requires-Dist: cartopy (>=0.23.0,<0.24.0)
|
15
15
|
Requires-Dist: contextily (>=1.6.0,<2.0.0)
|
16
16
|
Requires-Dist: darkdetect (>=0.8.0,<0.9.0)
|
17
|
+
Requires-Dist: einops (>=0.8.0,<0.9.0)
|
17
18
|
Requires-Dist: folium (>=0.17.0,<0.18.0)
|
18
|
-
Requires-Dist: geopandas (>=1.0.0,<2.0.0)
|
19
19
|
Requires-Dist: geopy (>=2.4.1,<3.0.0)
|
20
20
|
Requires-Dist: ipykernel (>=6.29.5,<7.0.0)
|
21
21
|
Requires-Dist: jax (==0.4.17)
|
22
22
|
Requires-Dist: jaxmarl (==0.0.3)
|
23
|
+
Requires-Dist: jaxtyping (>=0.2.33,<0.3.0)
|
23
24
|
Requires-Dist: jupyterlab (>=4.2.2,<5.0.0)
|
24
25
|
Requires-Dist: moviepy (>=1.0.3,<2.0.0)
|
25
26
|
Requires-Dist: numpy (<2)
|
26
27
|
Requires-Dist: opencv-python (>=4.10.0.84,<5.0.0.0)
|
27
|
-
Requires-Dist: osmnx (
|
28
|
+
Requires-Dist: osmnx (==2.0.0b0)
|
28
29
|
Requires-Dist: pandas (>=2.2.2,<3.0.0)
|
29
30
|
Requires-Dist: poetry (>=1.8.3,<2.0.0)
|
30
31
|
Requires-Dist: pygame (>=2.5.2,<3.0.0)
|
31
32
|
Requires-Dist: rasterio (>=1.3.10,<2.0.0)
|
32
33
|
Requires-Dist: seaborn (>=0.13.2,<0.14.0)
|
34
|
+
Requires-Dist: stadiamaps (>=3.2.1,<4.0.0)
|
33
35
|
Requires-Dist: tqdm (>=4.66.4,<5.0.0)
|
34
36
|
Project-URL: Repository, https://github.com/syrkis/parabellum
|
35
37
|
Description-Content-Type: text/markdown
|
@@ -0,0 +1,13 @@
|
|
1
|
+
parabellum/__init__.py,sha256=vqQbvsTT_zcLThZ7fLoJ6cMAZbEeGIJDFyCkHmovfOY,392
|
2
|
+
parabellum/aid.py,sha256=BPabjN4BUq1HRhkwbc9pCNsXSF_ALiG8W8cHWTWeEH4,900
|
3
|
+
parabellum/env.py,sha256=VV3VK7TTkianihqJopRbY0vlRWOquu-VTrc9ep0PSTk,21304
|
4
|
+
parabellum/geo.py,sha256=xkj6iJqN076tRbaG38Sq7gtwKSNzxI37msRLnpn5JV0,3561
|
5
|
+
parabellum/gun.py,sha256=nvsJdcZ2Qd6lbPlAgsUiaLhstTi1UdLQ8kOnbCenucY,2618
|
6
|
+
parabellum/map.py,sha256=9AV0PIqInXcWWojzHshy3X42Nm3ZDq0O1NG-6fQ9Wgw,3345
|
7
|
+
parabellum/pcg.py,sha256=d8KC_lbc4WUUUPaTdPJSx27VMGioys3jSGOWJ-2EahU,968
|
8
|
+
parabellum/run.py,sha256=Q53__AxzROZNgfZLVU5LDdcT61UMCkmQ_Q5wWUIrnqo,3473
|
9
|
+
parabellum/tps.py,sha256=3tVqo42ggE8idZn500C0X2pS9TmYndgBzlAG7Yj2Wz8,252
|
10
|
+
parabellum/vis.py,sha256=ABHveJj0fLRWkxOv3LFIXK20QtdGhjskuFLsp7iTFu0,6185
|
11
|
+
parabellum-0.3.0.dist-info/METADATA,sha256=FugXwz25bAPYKlIfqFc7dGVtPupse5zHYapmqBWopE8,2740
|
12
|
+
parabellum-0.3.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
13
|
+
parabellum-0.3.0.dist-info/RECORD,,
|
@@ -1,10 +0,0 @@
|
|
1
|
-
parabellum/__init__.py,sha256=-5cWXJkHnfH_CbhTEall8Wak8McAFXZHP1L8Fu7Uo5k,373
|
2
|
-
parabellum/aid.py,sha256=HWST27inTFXcp8b11izJF0U7N7DZnRTIS3n1Qfa-Ko4,106
|
3
|
-
parabellum/env.py,sha256=H1VwxHj8KqbjBJ8b7NMxkAg3Q0qwVzGulrigc26Tzkc,16663
|
4
|
-
parabellum/gun.py,sha256=nvsJdcZ2Qd6lbPlAgsUiaLhstTi1UdLQ8kOnbCenucY,2618
|
5
|
-
parabellum/map.py,sha256=UwMqwySasX5oLw9v5YMJARAPwvQThLTRW36NpbwvBC8,3564
|
6
|
-
parabellum/run.py,sha256=EO_F7VPwayatpSHrcbSahtinsV4QObhcx0jo-4KZO1E,3472
|
7
|
-
parabellum/vis.py,sha256=q7_OIMjzt-7nBOojVVW7Wiiq9ojsjaltIsH6eOOxPKk,6116
|
8
|
-
parabellum-0.2.25.dist-info/METADATA,sha256=UCuRLYhSUxnebs5pRs1FaWK67PUJw0kgsJ2UaaxNM9Q,2671
|
9
|
-
parabellum-0.2.25.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
10
|
-
parabellum-0.2.25.dist-info/RECORD,,
|
File without changes
|