parabellum 0.3.4__py3-none-any.whl → 0.5.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- parabellum/__init__.py +4 -18
- parabellum/aid.py +16 -2
- parabellum/env.py +109 -566
- parabellum/geo.py +29 -35
- parabellum/model.py +6 -0
- parabellum/ppo.py +1 -0
- parabellum/terrain_db.py +79 -47
- parabellum/types.py +54 -0
- {parabellum-0.3.4.dist-info → parabellum-0.5.13.dist-info}/METADATA +17 -15
- parabellum-0.5.13.dist-info/RECORD +15 -0
- {parabellum-0.3.4.dist-info → parabellum-0.5.13.dist-info}/WHEEL +1 -1
- parabellum/tps.py +0 -17
- parabellum-0.3.4.dist-info/RECORD +0 -13
parabellum/__init__.py
CHANGED
@@ -1,22 +1,8 @@
|
|
1
|
-
from .
|
2
|
-
from .vis import Visualizer, Skin
|
3
|
-
from .gun import bullet_fn
|
4
|
-
from . import vis
|
5
|
-
from . import terrain_db
|
6
|
-
from . import env
|
7
|
-
from . import tps
|
8
|
-
# from .run import run
|
1
|
+
from . import aid, env, geo, types
|
9
2
|
|
10
3
|
__all__ = [
|
4
|
+
"geo",
|
11
5
|
"env",
|
12
|
-
"
|
13
|
-
"
|
14
|
-
"tps",
|
15
|
-
"Environment",
|
16
|
-
"Scenario",
|
17
|
-
"make_scenario",
|
18
|
-
"State",
|
19
|
-
"Visualizer",
|
20
|
-
"Skin",
|
21
|
-
"bullet_fn",
|
6
|
+
"aid",
|
7
|
+
"types",
|
22
8
|
]
|
parabellum/aid.py
CHANGED
@@ -3,10 +3,9 @@
|
|
3
3
|
# by: Noah Syrkis
|
4
4
|
|
5
5
|
# imports
|
6
|
-
import os
|
7
6
|
from collections import namedtuple
|
8
|
-
from typing import Tuple
|
9
7
|
import cartopy.crs as ccrs
|
8
|
+
import jax.numpy as jnp
|
10
9
|
|
11
10
|
# types
|
12
11
|
BBox = namedtuple("BBox", ["north", "south", "east", "west"]) # type: ignore
|
@@ -23,5 +22,20 @@ def to_mercator(bbox: BBox) -> BBox:
|
|
23
22
|
def to_platecarree(bbox: BBox) -> BBox:
|
24
23
|
proj = ccrs.PlateCarree()
|
25
24
|
west, south = proj.transform_point(bbox.west, bbox.south, ccrs.Mercator())
|
25
|
+
|
26
26
|
east, north = proj.transform_point(bbox.east, bbox.north, ccrs.Mercator())
|
27
27
|
return BBox(north=north, south=south, east=east, west=west)
|
28
|
+
|
29
|
+
|
30
|
+
def obstacle_mask_fn(limit):
|
31
|
+
def aux(i, j):
|
32
|
+
xs = jnp.linspace(0, i + 1, i + j + 1)
|
33
|
+
ys = jnp.linspace(0, j + 1, i + j + 1)
|
34
|
+
cc = jnp.stack((xs, ys)).astype(jnp.int8)
|
35
|
+
mask = jnp.zeros((limit, limit)).at[*cc].set(1)
|
36
|
+
return mask
|
37
|
+
|
38
|
+
x = jnp.repeat(jnp.arange(limit), limit)
|
39
|
+
y = jnp.tile(jnp.arange(limit), limit)
|
40
|
+
mask = jnp.stack([aux(*c) for c in jnp.stack((x, y)).T])
|
41
|
+
return mask.astype(jnp.int8).reshape(limit, limit, limit, limit)
|
parabellum/env.py
CHANGED
@@ -1,570 +1,113 @@
|
|
1
|
-
|
1
|
+
# env.py
|
2
|
+
# parabellum env
|
3
|
+
# by: Noah Syrkis
|
2
4
|
|
5
|
+
# % Imports
|
3
6
|
import jax.numpy as jnp
|
4
|
-
import
|
5
|
-
|
6
|
-
from
|
7
|
-
import chex
|
8
|
-
from jaxmarl.environments.smax.smax_env import SMAX
|
9
|
-
|
10
|
-
from math import ceil
|
11
|
-
|
12
|
-
from typing import Tuple, Dict, cast
|
7
|
+
from jax import random, Array, lax, vmap, debug
|
8
|
+
import jax.numpy.linalg as la
|
9
|
+
from typing import Tuple
|
13
10
|
from functools import partial
|
14
|
-
from parabellum import tps, geo, terrain_db
|
15
|
-
|
16
|
-
|
17
|
-
@dataclass
|
18
|
-
class Scenario:
|
19
|
-
"""Parabellum scenario"""
|
20
|
-
|
21
|
-
place: str
|
22
|
-
terrain: tps.Terrain
|
23
|
-
unit_starting_sectors: jnp.ndarray # must be of size (num_units, 4) where sectors[i] = (x, y, width, height) of the ith unit's spawning sector (in % of the real map)
|
24
|
-
unit_types: chex.Array
|
25
|
-
num_allies: int
|
26
|
-
num_enemies: int
|
27
|
-
|
28
|
-
smacv2_position_generation: bool = False
|
29
|
-
smacv2_unit_type_generation: bool = False
|
30
|
-
|
31
|
-
|
32
|
-
@dataclass
|
33
|
-
class State:
|
34
|
-
# terrain: Array
|
35
|
-
unit_positions: Array # fsfds
|
36
|
-
unit_alive: Array
|
37
|
-
unit_teams: Array
|
38
|
-
unit_health: Array
|
39
|
-
unit_types: Array
|
40
|
-
unit_weapon_cooldowns: Array
|
41
|
-
prev_movement_actions: Array
|
42
|
-
prev_attack_actions: Array
|
43
|
-
time: int
|
44
|
-
terminal: bool
|
45
|
-
|
46
|
-
|
47
|
-
def make_scenario(
|
48
|
-
place,
|
49
|
-
size,
|
50
|
-
unit_starting_sectors,
|
51
|
-
allies_type,
|
52
|
-
n_allies,
|
53
|
-
enemies_type,
|
54
|
-
n_enemies,
|
55
|
-
):
|
56
|
-
if place in terrain_db.db:
|
57
|
-
terrain = terrain_db.make_terrain(terrain_db.db[place], size)
|
58
|
-
else:
|
59
|
-
terrain = geo.geography_fn(place, size)
|
60
|
-
if type(unit_starting_sectors) == list:
|
61
|
-
default_sector = [
|
62
|
-
0,
|
63
|
-
0,
|
64
|
-
size,
|
65
|
-
size,
|
66
|
-
] # Noah feel confident that this is right. This means 50% chance. Sorry timothee if you end up here later. my bad bro.
|
67
|
-
correct_unit_starting_sectors = []
|
68
|
-
for i in range(n_allies + n_enemies):
|
69
|
-
selected_sector = None
|
70
|
-
for unit_ids, sector in unit_starting_sectors:
|
71
|
-
if i in unit_ids:
|
72
|
-
selected_sector = sector
|
73
|
-
if selected_sector is None:
|
74
|
-
selected_sector = default_sector
|
75
|
-
correct_unit_starting_sectors.append(selected_sector)
|
76
|
-
unit_starting_sectors = correct_unit_starting_sectors
|
77
|
-
if type(allies_type) == int:
|
78
|
-
allies = [allies_type] * n_allies
|
79
|
-
else:
|
80
|
-
assert len(allies_type) == n_allies
|
81
|
-
allies = allies_type
|
82
|
-
|
83
|
-
if type(enemies_type) == int:
|
84
|
-
enemies = [enemies_type] * n_enemies
|
85
|
-
else:
|
86
|
-
assert len(enemies_type) == n_enemies
|
87
|
-
enemies = enemies_type
|
88
|
-
unit_types = jnp.array(allies + enemies, dtype=jnp.uint8)
|
89
|
-
return Scenario(
|
90
|
-
place,
|
91
|
-
terrain,
|
92
|
-
unit_starting_sectors, # type: ignore
|
93
|
-
unit_types,
|
94
|
-
n_allies,
|
95
|
-
n_enemies,
|
96
|
-
)
|
97
|
-
|
98
|
-
|
99
|
-
def scenario_fn(place, size):
|
100
|
-
# scenario function for Noah, cos the one above is confusing
|
101
|
-
terrain = geo.geography_fn(place, size)
|
102
|
-
num_allies = 10
|
103
|
-
num_enemies = 10
|
104
|
-
unit_types = jnp.array([0] * num_allies + [1] * num_enemies, dtype=jnp.uint8)
|
105
|
-
# start units in default sectors
|
106
|
-
unit_starting_sectors = jnp.array([[0, 0, 1, 1]] * (num_allies + num_enemies))
|
107
|
-
return Scenario(
|
108
|
-
place=place,
|
109
|
-
terrain=terrain,
|
110
|
-
unit_starting_sectors=unit_starting_sectors,
|
111
|
-
unit_types=unit_types,
|
112
|
-
num_allies=num_allies,
|
113
|
-
num_enemies=num_enemies,
|
114
|
-
)
|
115
|
-
|
116
|
-
|
117
|
-
def spawn_fn(rng: jnp.ndarray, units_spawning_sectors):
|
118
|
-
"""Spawns n agents on a map."""
|
119
|
-
spawn_positions = []
|
120
|
-
for sector in units_spawning_sectors:
|
121
|
-
rng, key_start, key_noise = random.split(rng, 3)
|
122
|
-
noise = 0.25 + random.uniform(key_noise, (2,)) * 0.5
|
123
|
-
idx = random.choice(key_start, sector[0].shape[0])
|
124
|
-
coord = jnp.array([sector[0][idx], sector[1][idx]])
|
125
|
-
spawn_positions.append(coord + noise)
|
126
|
-
return jnp.array(spawn_positions, dtype=jnp.float32)
|
127
|
-
|
128
|
-
|
129
|
-
def sectors_fn(sectors: jnp.ndarray, invalid_spawn_areas: jnp.ndarray):
|
130
|
-
"""
|
131
|
-
sectors must be of size (num_units, 4) where sectors[i] = (x, y, width, height) of the ith unit's spawning sector (in % of the real map)
|
132
|
-
"""
|
133
|
-
width, height = invalid_spawn_areas.shape
|
134
|
-
spawning_sectors = []
|
135
|
-
for sector in sectors:
|
136
|
-
coordx, coordy = (
|
137
|
-
jnp.array(sector[0] * width, dtype=jnp.int32),
|
138
|
-
jnp.array(sector[1] * height, dtype=jnp.int32),
|
139
|
-
)
|
140
|
-
sector = (
|
141
|
-
invalid_spawn_areas[
|
142
|
-
coordx : coordx + ceil(sector[2] * width),
|
143
|
-
coordy : coordy + ceil(sector[3] * height),
|
144
|
-
]
|
145
|
-
== 0
|
146
|
-
)
|
147
|
-
valid = jnp.nonzero(sector)
|
148
|
-
if valid[0].shape[0] == 0:
|
149
|
-
raise ValueError(f"Sector {sector} only contains invalid spawn areas.")
|
150
|
-
spawning_sectors.append(
|
151
|
-
jnp.array(valid) + jnp.array([coordx, coordy]).reshape((2, -1))
|
152
|
-
)
|
153
|
-
return spawning_sectors
|
154
|
-
|
155
|
-
|
156
|
-
class Environment(SMAX):
|
157
|
-
def __init__(self, scenario: Scenario, **kwargs):
|
158
|
-
map_height, map_width = scenario.terrain.building.shape
|
159
|
-
args = dict(scenario=scenario, map_height=map_height, map_width=map_width)
|
160
|
-
if "unit_type_pushable" in kwargs:
|
161
|
-
self.unit_type_pushable = kwargs["unit_type_pushable"]
|
162
|
-
del kwargs["unit_type_pushable"]
|
163
|
-
else:
|
164
|
-
self.unit_type_pushable = jnp.array([1, 1, 0, 0, 0, 1])
|
165
|
-
if "reset_when_done" in kwargs:
|
166
|
-
self.reset_when_done = kwargs["reset_when_done"]
|
167
|
-
del kwargs["reset_when_done"]
|
168
|
-
else:
|
169
|
-
self.reset_when_done = True
|
170
|
-
super(Environment, self).__init__(**args, walls_cause_death=False, **kwargs)
|
171
|
-
self.terrain = scenario.terrain
|
172
|
-
self.unit_starting_sectors = scenario.unit_starting_sectors
|
173
|
-
# self.unit_type_names = ["tinker", "tailor", "soldier", "spy"]
|
174
|
-
# self.unit_type_health = jnp.array([100, 100, 100, 100], dtype=jnp.float32)
|
175
|
-
# self.unit_type_damage = jnp.array([10, 10, 10, 10], dtype=jnp.float32)
|
176
|
-
self.scenario = scenario
|
177
|
-
self.unit_type_velocities = (
|
178
|
-
jnp.array([3.15, 2.25, 4.13, 3.15, 4.13, 3.15]) / 2.5
|
179
|
-
if "unit_type_velocities" not in kwargs
|
180
|
-
else kwargs["unit_type_velocities"]
|
181
|
-
)
|
182
|
-
self.unit_type_attack_blasts = jnp.zeros((3,), dtype=jnp.float32) # TODO: add
|
183
|
-
self.max_steps = 200
|
184
|
-
self._push_units_away = lambda state, firmness=1: state # overwrite push units
|
185
|
-
self.spawning_sectors = sectors_fn(
|
186
|
-
self.unit_starting_sectors,
|
187
|
-
scenario.terrain.building + scenario.terrain.water,
|
188
|
-
)
|
189
|
-
self.resolution = (
|
190
|
-
jnp.array(jnp.max(self.unit_type_sight_ranges), dtype=jnp.int32) * 2
|
191
|
-
)
|
192
|
-
self.t = jnp.tile(jnp.linspace(0, 1, self.resolution), (2, 1))
|
193
|
-
|
194
|
-
def reset(self, rng: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]: # type: ignore
|
195
|
-
"""Environment-specific reset."""
|
196
|
-
unit_positions = spawn_fn(rng, self.spawning_sectors)
|
197
|
-
unit_teams = jnp.zeros((self.num_agents,))
|
198
|
-
unit_teams = unit_teams.at[self.num_allies :].set(1)
|
199
|
-
unit_weapon_cooldowns = jnp.zeros((self.num_agents,))
|
200
|
-
# default behaviour spawn all marines
|
201
|
-
unit_types = cast(Array, self.scenario.unit_types)
|
202
|
-
unit_health = self.unit_type_health[unit_types]
|
203
|
-
state = State(
|
204
|
-
unit_positions=unit_positions,
|
205
|
-
unit_alive=jnp.ones((self.num_agents,), dtype=jnp.bool_),
|
206
|
-
unit_teams=unit_teams,
|
207
|
-
unit_health=unit_health,
|
208
|
-
unit_types=unit_types,
|
209
|
-
prev_movement_actions=jnp.zeros((self.num_agents, 2)),
|
210
|
-
prev_attack_actions=jnp.zeros((self.num_agents,), dtype=jnp.int32),
|
211
|
-
time=0,
|
212
|
-
terminal=False,
|
213
|
-
unit_weapon_cooldowns=unit_weapon_cooldowns,
|
214
|
-
# terrain=self.terrain,
|
215
|
-
)
|
216
|
-
state = self._push_units_away(state) # type: ignore could be slow
|
217
|
-
obs = self.get_obs(state)
|
218
|
-
# remove world_state from obs
|
219
|
-
world_state = self.get_world_state(state)
|
220
|
-
obs["world_state"] = jax.lax.stop_gradient(world_state)
|
221
|
-
return obs, state
|
222
|
-
|
223
|
-
# def step_env(self, rng, state: State, action: Array): # type: ignore
|
224
|
-
# obs, state, rewards, dones, infos = super().step_env(rng, state, action)
|
225
|
-
# delete world_state from obs
|
226
|
-
<<<<<<< HEAD
|
227
|
-
# obs.pop("world_state")
|
228
|
-
# if not self.reset_when_done:
|
229
|
-
# for key in dones.keys():
|
230
|
-
# dones[key] = False
|
231
|
-
# return obs, state, rewards, dones, infos
|
232
|
-
=======
|
233
|
-
obs.pop("world_state")
|
234
|
-
if not self.reset_when_done:
|
235
|
-
for key in dones.keys():
|
236
|
-
infos[key] = dones[key]
|
237
|
-
dones[key] = False
|
238
|
-
return obs, state, rewards, dones, infos
|
239
|
-
>>>>>>> aeb13033e57083cc512a60f8f60a3db47a65ac32
|
240
|
-
|
241
|
-
def get_obs_unit_list(self, state: State) -> Dict[str, chex.Array]: # type: ignore
|
242
|
-
"""Applies observation function to state."""
|
243
|
-
|
244
|
-
def get_features(i, j):
|
245
|
-
"""Get features of unit j as seen from unit i"""
|
246
|
-
# Can just keep them symmetrical for now.
|
247
|
-
# j here means 'the jth unit that is not i'
|
248
|
-
# The observation is such that allies are always first
|
249
|
-
# so for units in the second team we count in reverse.
|
250
|
-
j = jax.lax.cond(
|
251
|
-
i < self.num_allies, lambda: j, lambda: self.num_agents - j - 1
|
252
|
-
)
|
253
|
-
offset = jax.lax.cond(i < self.num_allies, lambda: 1, lambda: -1)
|
254
|
-
j_idx = jax.lax.cond(
|
255
|
-
((j < i) & (i < self.num_allies)) | ((j > i) & (i >= self.num_allies)),
|
256
|
-
lambda: j,
|
257
|
-
lambda: j + offset,
|
258
|
-
)
|
259
|
-
empty_features = jnp.zeros(shape=(len(self.unit_features),))
|
260
|
-
features = self._observe_features(state, i, j_idx)
|
261
|
-
visible = (
|
262
|
-
jnp.linalg.norm(state.unit_positions[j_idx] - state.unit_positions[i])
|
263
|
-
< self.unit_type_sight_ranges[state.unit_types[i]]
|
264
|
-
)
|
265
|
-
return jax.lax.cond(
|
266
|
-
visible
|
267
|
-
& state.unit_alive[i]
|
268
|
-
& state.unit_alive[j_idx]
|
269
|
-
& self.has_line_of_sight(
|
270
|
-
state.unit_positions[j_idx],
|
271
|
-
state.unit_positions[i],
|
272
|
-
self.terrain.building + self.terrain.forest,
|
273
|
-
),
|
274
|
-
lambda: features,
|
275
|
-
lambda: empty_features,
|
276
|
-
)
|
277
|
-
|
278
|
-
get_all_features_for_unit = jax.vmap(get_features, in_axes=(None, 0))
|
279
|
-
get_all_features = jax.vmap(get_all_features_for_unit, in_axes=(0, None))
|
280
|
-
other_unit_obs = get_all_features(
|
281
|
-
jnp.arange(self.num_agents), jnp.arange(self.num_agents - 1)
|
282
|
-
)
|
283
|
-
other_unit_obs = other_unit_obs.reshape((self.num_agents, -1))
|
284
|
-
get_all_self_features = jax.vmap(self._get_own_features, in_axes=(None, 0))
|
285
|
-
own_unit_obs = get_all_self_features(state, jnp.arange(self.num_agents))
|
286
|
-
obs = jnp.concatenate([other_unit_obs, own_unit_obs], axis=-1)
|
287
|
-
return {agent: obs[self.agent_ids[agent]] for agent in self.agents}
|
288
|
-
|
289
|
-
def _our_push_units_away(
|
290
|
-
self, pos, unit_types, firmness: float = 1.0
|
291
|
-
): # copy of SMAX._push_units_away but used without state and called inside _world_step to allow more obstacles constraints
|
292
|
-
delta_matrix = pos[:, None] - pos[None, :]
|
293
|
-
dist_matrix = (
|
294
|
-
jnp.linalg.norm(delta_matrix, axis=-1)
|
295
|
-
+ jnp.identity(self.num_agents)
|
296
|
-
+ 1e-6
|
297
|
-
)
|
298
|
-
radius_matrix = (
|
299
|
-
self.unit_type_radiuses[unit_types][:, None]
|
300
|
-
+ self.unit_type_radiuses[unit_types][None, :]
|
301
|
-
)
|
302
|
-
overlap_term = jax.nn.relu(radius_matrix / dist_matrix - 1.0)
|
303
|
-
unit_positions = (
|
304
|
-
pos
|
305
|
-
+ firmness * jnp.sum(delta_matrix * overlap_term[:, :, None], axis=1) / 2
|
306
|
-
)
|
307
|
-
return jnp.where(
|
308
|
-
self.unit_type_pushable[unit_types][:, None], unit_positions, pos
|
309
|
-
)
|
310
|
-
|
311
|
-
def has_line_of_sight(self, source, target, raster_input):
|
312
|
-
# suppose that the target is in sight_range of source, otherwise the line of sight might miss some cells
|
313
|
-
cells = jnp.array(
|
314
|
-
source[:, jnp.newaxis] * self.t + (1 - self.t) * target[:, jnp.newaxis],
|
315
|
-
dtype=jnp.int32,
|
316
|
-
)
|
317
|
-
mask = jnp.zeros(raster_input.shape).at[cells[0, :], cells[1, :]].set(1)
|
318
|
-
flag = ~jnp.any(jnp.logical_and(mask, raster_input))
|
319
|
-
return flag
|
320
|
-
|
321
|
-
@partial(jax.jit, static_argnums=(0,)) # replace the _world_step method
|
322
|
-
def _world_step( # modified version of JaxMARL's SMAX _world_step
|
323
|
-
self,
|
324
|
-
key: chex.PRNGKey,
|
325
|
-
state: State,
|
326
|
-
actions: Tuple[chex.Array, chex.Array],
|
327
|
-
) -> State:
|
328
|
-
def raster_crossing(pos, new_pos, mask: jnp.ndarray):
|
329
|
-
pos, new_pos = pos.astype(jnp.int32), new_pos.astype(jnp.int32)
|
330
|
-
minimum = jnp.minimum(pos, new_pos)
|
331
|
-
maximum = jnp.maximum(pos, new_pos)
|
332
|
-
mask = jnp.where(jnp.arange(mask.shape[0]) >= minimum[0], mask.T, 0).T
|
333
|
-
mask = jnp.where(jnp.arange(mask.shape[0]) <= maximum[0], mask.T, 0).T
|
334
|
-
mask = jnp.where(jnp.arange(mask.shape[1]) >= minimum[1], mask, 0)
|
335
|
-
mask = jnp.where(jnp.arange(mask.shape[1]) <= maximum[1], mask, 0)
|
336
|
-
return jnp.any(mask)
|
337
|
-
|
338
|
-
def update_position(idx, vec):
|
339
|
-
# Compute the movements slightly strangely.
|
340
|
-
# The velocities below are for diagonal directions
|
341
|
-
# because these are easier to encode as actions than the four
|
342
|
-
# diagonal directions. Then rotate the velocity 45
|
343
|
-
# degrees anticlockwise to compute the movement.
|
344
|
-
pos = cast(Array, state.unit_positions[idx])
|
345
|
-
new_pos = (
|
346
|
-
pos
|
347
|
-
+ vec
|
348
|
-
* self.unit_type_velocities[state.unit_types[idx]]
|
349
|
-
* self.time_per_step
|
350
|
-
)
|
351
|
-
# avoid going out of bounds
|
352
|
-
new_pos = jnp.maximum(
|
353
|
-
jnp.minimum(
|
354
|
-
new_pos, jnp.array([self.map_width - 1, self.map_height - 1])
|
355
|
-
),
|
356
|
-
jnp.zeros((2,)),
|
357
|
-
)
|
358
|
-
|
359
|
-
#######################################################################
|
360
|
-
############################################ avoid going into obstacles
|
361
|
-
clash = raster_crossing(
|
362
|
-
pos, new_pos, self.terrain.building + self.terrain.water
|
363
|
-
)
|
364
|
-
new_pos = jnp.where(clash, pos, new_pos)
|
365
|
-
|
366
|
-
#######################################################################
|
367
|
-
#######################################################################
|
368
|
-
|
369
|
-
return new_pos
|
370
|
-
|
371
|
-
#######################################################################
|
372
|
-
######################################### units close enough to get hit
|
373
|
-
|
374
|
-
def bystander_fn(attacked_idx):
|
375
|
-
idxs = jnp.zeros((self.num_agents,))
|
376
|
-
idxs *= (
|
377
|
-
jnp.linalg.norm(
|
378
|
-
state.unit_positions - state.unit_positions[attacked_idx], axis=-1
|
379
|
-
)
|
380
|
-
< self.unit_type_attack_blasts[state.unit_types[attacked_idx]]
|
381
|
-
)
|
382
|
-
return idxs
|
383
|
-
|
384
|
-
#######################################################################
|
385
|
-
#######################################################################
|
386
|
-
|
387
|
-
def update_agent_health(idx, action, key): # TODO: add attack blasts
|
388
|
-
# for team 1, their attack actions are labelled in
|
389
|
-
# reverse order because that is the order they are
|
390
|
-
# observed in
|
391
|
-
attacked_idx = jax.lax.cond(
|
392
|
-
idx < self.num_allies,
|
393
|
-
lambda: action + self.num_allies - self.num_movement_actions,
|
394
|
-
lambda: self.num_allies - 1 - (action - self.num_movement_actions),
|
395
|
-
)
|
396
|
-
attacked_idx = cast(int, attacked_idx) # Cast to int
|
397
|
-
# deal with no-op attack actions (i.e. agents that are moving instead)
|
398
|
-
attacked_idx = jax.lax.select(
|
399
|
-
action < self.num_movement_actions, idx, attacked_idx
|
400
|
-
)
|
401
|
-
distance = jnp.linalg.norm(
|
402
|
-
state.unit_positions[idx] - state.unit_positions[attacked_idx]
|
403
|
-
)
|
404
|
-
attack_valid = (
|
405
|
-
(distance <= self.unit_type_attack_ranges[state.unit_types[idx]])
|
406
|
-
& state.unit_alive[idx]
|
407
|
-
& state.unit_alive[attacked_idx]
|
408
|
-
)
|
409
|
-
attack_valid = attack_valid & (idx != attacked_idx)
|
410
|
-
attack_valid = attack_valid & (state.unit_weapon_cooldowns[idx] <= 0.0)
|
411
|
-
health_diff = jax.lax.select(
|
412
|
-
attack_valid,
|
413
|
-
-self.unit_type_attacks[state.unit_types[idx]],
|
414
|
-
0.0,
|
415
|
-
)
|
416
|
-
health_diff = jnp.where(
|
417
|
-
state.unit_types[idx] == 1,
|
418
|
-
health_diff
|
419
|
-
* distance
|
420
|
-
/ self.unit_type_attack_ranges[state.unit_types[idx]],
|
421
|
-
health_diff,
|
422
|
-
)
|
423
|
-
# design choice based on the pysc2 randomness details.
|
424
|
-
# See https://github.com/deepmind/pysc2/blob/master/docs/environment.md#determinism-and-randomness
|
425
|
-
|
426
|
-
#########################################################
|
427
|
-
############################### Add bystander health diff
|
428
|
-
|
429
|
-
# bystander_idxs = bystander_fn(attacked_idx) # TODO: use
|
430
|
-
# bystander_valid = (
|
431
|
-
# jnp.where(attack_valid, bystander_idxs, jnp.zeros((self.num_agents,)))
|
432
|
-
# .astype(jnp.bool_) # type: ignore
|
433
|
-
# .astype(jnp.float32)
|
434
|
-
# )
|
435
|
-
# bystander_health_diff = (
|
436
|
-
# bystander_valid * -self.unit_type_attacks[state.unit_types[idx]]
|
437
|
-
# )
|
438
|
-
|
439
|
-
#########################################################
|
440
|
-
#########################################################
|
441
|
-
|
442
|
-
cooldown_deviation = jax.random.uniform(
|
443
|
-
key, minval=-self.time_per_step, maxval=2 * self.time_per_step
|
444
|
-
)
|
445
|
-
cooldown = (
|
446
|
-
self.unit_type_weapon_cooldowns[state.unit_types[idx]]
|
447
|
-
+ cooldown_deviation
|
448
|
-
)
|
449
|
-
cooldown_diff = jax.lax.select(
|
450
|
-
attack_valid,
|
451
|
-
# subtract the current cooldown because we are
|
452
|
-
# going to add it back. This way we effectively
|
453
|
-
# set the new cooldown to `cooldown`
|
454
|
-
cooldown - state.unit_weapon_cooldowns[idx],
|
455
|
-
-self.time_per_step,
|
456
|
-
)
|
457
|
-
return (
|
458
|
-
health_diff,
|
459
|
-
attacked_idx,
|
460
|
-
cooldown_diff,
|
461
|
-
# (bystander_health_diff, bystander_idxs),
|
462
|
-
)
|
463
|
-
|
464
|
-
def perform_agent_action(idx, action, key):
|
465
|
-
movement_action, attack_action = action
|
466
|
-
new_pos = update_position(idx, movement_action)
|
467
|
-
health_diff, attacked_idxes, cooldown_diff = update_agent_health(
|
468
|
-
idx, attack_action, key
|
469
|
-
)
|
470
|
-
|
471
|
-
return new_pos, (health_diff, attacked_idxes), cooldown_diff
|
472
|
-
|
473
|
-
keys = jax.random.split(key, num=self.num_agents)
|
474
|
-
pos, (health_diff, attacked_idxes), cooldown_diff = jax.vmap(
|
475
|
-
perform_agent_action
|
476
|
-
)(jnp.arange(self.num_agents), actions, keys)
|
477
|
-
|
478
|
-
# units push each other
|
479
|
-
new_pos = self._our_push_units_away(pos, state.unit_types)
|
480
|
-
clash = jax.vmap(raster_crossing, in_axes=(0, 0, None))(
|
481
|
-
pos, new_pos, self.terrain.building + self.terrain.water
|
482
|
-
)
|
483
|
-
pos = jax.vmap(jnp.where)(clash, pos, new_pos)
|
484
|
-
# avoid going out of bounds
|
485
|
-
pos = jnp.maximum(
|
486
|
-
jnp.minimum(pos, jnp.array([self.map_width - 1, self.map_height - 1])), # type: ignore
|
487
|
-
jnp.zeros((2,)),
|
488
|
-
)
|
489
|
-
|
490
|
-
# Multiple enemies can attack the same unit.
|
491
|
-
# We have `(health_diff, attacked_idx)` pairs.
|
492
|
-
# `jax.lax.scatter_add` aggregates these exactly
|
493
|
-
# in the way we want -- duplicate idxes will have their
|
494
|
-
# health differences added together. However, it is a
|
495
|
-
# super thin wrapper around the XLA scatter operation,
|
496
|
-
# which has this bonkers syntax and requires this dnums
|
497
|
-
# parameter. The usage here was inferred from a test:
|
498
|
-
# https://github.com/google/jax/blob/main/tests/lax_test.py#L2296
|
499
|
-
dnums = jax.lax.ScatterDimensionNumbers(
|
500
|
-
update_window_dims=(),
|
501
|
-
inserted_window_dims=(0,),
|
502
|
-
scatter_dims_to_operand_dims=(0,),
|
503
|
-
)
|
504
|
-
unit_health = jnp.maximum(
|
505
|
-
jax.lax.scatter_add(
|
506
|
-
state.unit_health,
|
507
|
-
jnp.expand_dims(attacked_idxes, 1),
|
508
|
-
health_diff,
|
509
|
-
dnums,
|
510
|
-
),
|
511
|
-
0.0,
|
512
|
-
)
|
513
|
-
|
514
|
-
#########################################################
|
515
|
-
############################ subtracting bystander health
|
516
|
-
|
517
|
-
# _, bystander_health_diff = bystander
|
518
|
-
# unit_health -= bystander_health_diff.sum(axis=0) # might be axis=1
|
519
|
-
|
520
|
-
#########################################################
|
521
|
-
#########################################################
|
522
|
-
|
523
|
-
unit_weapon_cooldowns = state.unit_weapon_cooldowns + cooldown_diff
|
524
|
-
# replace unit health, unit positions and unit weapon cooldowns
|
525
|
-
state = state.replace( # type: ignore
|
526
|
-
unit_health=unit_health,
|
527
|
-
unit_positions=pos,
|
528
|
-
unit_weapon_cooldowns=unit_weapon_cooldowns,
|
529
|
-
)
|
530
|
-
return state
|
531
|
-
|
532
|
-
|
533
|
-
if __name__ == "__main__":
|
534
|
-
n_envs = 4
|
535
|
-
|
536
|
-
n_allies = 10
|
537
|
-
scenario_kwargs = {
|
538
|
-
"allies_type": 0,
|
539
|
-
"n_allies": n_allies,
|
540
|
-
"enemies_type": 0,
|
541
|
-
"n_enemies": n_allies,
|
542
|
-
"place": "Vesterbro, Copenhagen, Denmark",
|
543
|
-
"size": 100,
|
544
|
-
"unit_starting_sectors": [
|
545
|
-
([i for i in range(n_allies)], [0.0, 0.45, 0.1, 0.1]),
|
546
|
-
([n_allies + i for i in range(n_allies)], [0.8, 0.5, 0.1, 0.1]),
|
547
|
-
],
|
548
|
-
}
|
549
|
-
scenario = make_scenario(**scenario_kwargs)
|
550
|
-
env = Environment(scenario)
|
551
|
-
rng, reset_rng = random.split(random.PRNGKey(0))
|
552
|
-
reset_key = random.split(reset_rng, n_envs)
|
553
|
-
obs, state = vmap(env.reset)(reset_key)
|
554
|
-
state_seq = []
|
555
|
-
|
556
|
-
import time
|
557
11
|
|
558
|
-
|
559
|
-
|
560
|
-
|
561
|
-
|
562
|
-
|
563
|
-
|
564
|
-
|
565
|
-
|
566
|
-
|
567
|
-
|
568
|
-
|
569
|
-
|
570
|
-
|
12
|
+
from parabellum.geo import geography_fn
|
13
|
+
from parabellum.types import Action, State, Obs, Scene
|
14
|
+
from parabellum import aid
|
15
|
+
import equinox as eqx
|
16
|
+
|
17
|
+
|
18
|
+
# %% Dataclass ################################################################
|
19
|
+
class Env:
|
20
|
+
def __init__(self, cfg):
|
21
|
+
self.cfg = cfg
|
22
|
+
|
23
|
+
def reset(self, rng: Array, scene: Scene) -> Tuple[Obs, State]:
|
24
|
+
return init_fn(rng, self, scene)
|
25
|
+
|
26
|
+
def step(self, rng: Array, scene: Scene, state: State, action: Action) -> Tuple[Obs, State]:
|
27
|
+
return obs_fn(self, scene, state), step_fn(rng, self, scene, state, action)
|
28
|
+
|
29
|
+
@property
|
30
|
+
def num_units(self):
|
31
|
+
return sum(self.cfg.counts.allies.values()) + sum(self.cfg.counts.enemies.values())
|
32
|
+
|
33
|
+
@property
|
34
|
+
def num_allies(self):
|
35
|
+
return sum(self.cfg.counts.allies.values())
|
36
|
+
|
37
|
+
@property
|
38
|
+
def num_enemies(self):
|
39
|
+
return sum(self.cfg.counts.enemies.values())
|
40
|
+
|
41
|
+
|
42
|
+
# %% Functions ################################################################
|
43
|
+
@eqx.filter_jit
|
44
|
+
def init_fn(rng: Array, env: Env, scene: Scene) -> Tuple[Obs, State]: # initialize -----
|
45
|
+
keys = random.split(rng)
|
46
|
+
health = jnp.ones(env.num_units) * scene.unit_type_health[scene.unit_types]
|
47
|
+
pos = random.normal(keys[1], (scene.unit_types.size, 2)) * 2 + env.cfg.size / 2
|
48
|
+
state = State(unit_position=pos, unit_health=health, unit_cooldown=jnp.zeros(env.num_units)) # state --
|
49
|
+
return obs_fn(env, scene, state), state # return observation and state of agents --
|
50
|
+
|
51
|
+
|
52
|
+
@eqx.filter_jit # knn from env.cfg never changes, so we can jit it
|
53
|
+
def obs_fn(env, scene: Scene, state: State) -> Obs: # return info about neighbors ---
|
54
|
+
distances = la.norm(state.unit_position[:, None] - state.unit_position, axis=-1) # all dist --
|
55
|
+
dists, idxs = lax.approx_min_k(distances, k=env.cfg.knn)
|
56
|
+
mask = mask_fn(scene, state, dists, idxs)
|
57
|
+
health = state.unit_health[idxs] * mask
|
58
|
+
cooldown = state.unit_cooldown[idxs] * mask
|
59
|
+
unit_pos = (state.unit_position[:, None, ...] - state.unit_position[idxs]) * mask[..., None]
|
60
|
+
return Obs(unit_id=idxs, unit_pos=unit_pos, unit_health=health, unit_cooldown=cooldown)
|
61
|
+
|
62
|
+
|
63
|
+
@eqx.filter_jit
|
64
|
+
def step_fn(rng, env: Env, scene: Scene, state: State, action: Action) -> State: # update agents ---
|
65
|
+
newpos = state.unit_position + action.coord * (1 - action.kinds[..., None])
|
66
|
+
bounds = ((newpos < 0).any(axis=-1) | (newpos >= env.cfg.size).any(axis=-1))[..., None]
|
67
|
+
builds = (scene.terrain.building[*newpos.astype(jnp.int32).T] > 0)[..., None]
|
68
|
+
newpos = jnp.where(bounds | builds, state.unit_position, newpos) # use old pos if new is not valid
|
69
|
+
new_hp = blast_fn(rng, env, scene, state, action)
|
70
|
+
return State(unit_position=newpos, unit_health=new_hp, unit_cooldown=state.unit_cooldown) # return -
|
71
|
+
|
72
|
+
|
73
|
+
def blast_fn(rng, env: Env, scene: Scene, state: State, action: Action): # update agents ---
|
74
|
+
dist = la.norm(state.unit_position[None, ...] - (state.unit_position + action.coord)[:, None, ...], axis=-1)
|
75
|
+
hits = dist <= scene.unit_type_reach[scene.unit_types][None, ...] * action.kinds[..., None] # mask non attack act
|
76
|
+
damage = (hits * scene.unit_type_damage[scene.unit_types][None, ...]).sum(axis=-1)
|
77
|
+
return state.unit_health - damage
|
78
|
+
|
79
|
+
|
80
|
+
# @eqx.filter_jit
|
81
|
+
def scene_fn(cfg): # init's a scene
|
82
|
+
aux = lambda key: jnp.array([x[key] for x in sorted(cfg.types, key=lambda x: x.name)]) # noqa
|
83
|
+
attrs = ["health", "damage", "reload", "reach", "sight", "speed"]
|
84
|
+
kwargs = {f"unit_type_{a}": aux(a) for a in attrs} | {"terrain": geography_fn(cfg.place, cfg.size)}
|
85
|
+
num_allies, num_enemies = sum(cfg.counts.allies.values()), sum(cfg.counts.enemies.values())
|
86
|
+
unit_teams = jnp.concat((jnp.zeros(num_allies), jnp.ones(num_enemies))).astype(jnp.int32)
|
87
|
+
aux = lambda t: jnp.concat([jnp.zeros(x) + i for i, x in enumerate([x[1] for x in sorted(cfg.counts[t].items())])]) # noqa
|
88
|
+
unit_types = jnp.concat((aux("allies"), aux("enemies"))).astype(jnp.int32)
|
89
|
+
mask = aid.obstacle_mask_fn(max([x["sight"] for x in cfg.types]))
|
90
|
+
return Scene(unit_teams=unit_teams, unit_types=unit_types, mask=mask, **kwargs) # type: ignore
|
91
|
+
|
92
|
+
|
93
|
+
@eqx.filter_jit
|
94
|
+
def mask_fn(scene, state, dists, idxs):
|
95
|
+
mask = dists < scene.unit_type_sight[scene.unit_types][..., None] # mask for removing hidden
|
96
|
+
mask = mask | obstacle_fn(scene, state.unit_position[idxs].astype(jnp.int8))
|
97
|
+
return mask
|
98
|
+
|
99
|
+
|
100
|
+
@partial(vmap, in_axes=(None, 0)) # 5 x 2 # not the best name for a fn
|
101
|
+
def obstacle_fn(scene, pos):
|
102
|
+
slice = slice_fn(scene, pos[0], pos)
|
103
|
+
return slice
|
104
|
+
|
105
|
+
|
106
|
+
@partial(vmap, in_axes=(None, None, 0))
|
107
|
+
def slice_fn(scene, source, target): # returns a 10 x 10 view with unit at top left corner, and terrain downwards
|
108
|
+
delta = ((source - target) >= 0) * 2 - 1
|
109
|
+
coord = jnp.sort(jnp.stack((source, source + delta * 10)), axis=0)[0]
|
110
|
+
slice = lax.dynamic_slice(scene.terrain.building, coord, (scene.mask.shape[-1], scene.mask.shape[-1]))
|
111
|
+
slice = lax.cond(delta[0] == 1, lambda: jnp.flip(slice), lambda: slice)
|
112
|
+
slice = lax.cond(delta[1] == 1, lambda: jnp.flip(slice, axis=1), lambda: slice)
|
113
|
+
return (scene.mask[*jnp.abs(source - target)] * slice).sum() == 0
|
parabellum/geo.py
CHANGED
@@ -3,25 +3,23 @@
|
|
3
3
|
# by: Noah Syrkis
|
4
4
|
|
5
5
|
# %% Imports
|
6
|
-
from parabellum import tps
|
7
|
-
import rasterio
|
8
6
|
from rasterio import features, transform
|
7
|
+
from jax import tree
|
9
8
|
from geopy.geocoders import Nominatim
|
10
9
|
from geopy.distance import distance
|
11
10
|
import contextily as cx
|
12
11
|
import jax.numpy as jnp
|
13
12
|
import cartopy.crs as ccrs
|
14
13
|
from jaxtyping import Array
|
15
|
-
import numpy as np
|
16
14
|
from shapely import box
|
17
15
|
import osmnx as ox
|
18
16
|
import geopandas as gpd
|
19
17
|
from collections import namedtuple
|
20
18
|
from typing import Tuple
|
21
19
|
import matplotlib.pyplot as plt
|
22
|
-
|
23
|
-
import os
|
20
|
+
from cachier import cachier
|
24
21
|
from jax.scipy.signal import convolve
|
22
|
+
from parabellum.types import Terrain
|
25
23
|
|
26
24
|
# %% Types
|
27
25
|
Coords = Tuple[float, float]
|
@@ -79,27 +77,23 @@ def basemap_fn(bbox: BBox, gdf) -> Array:
|
|
79
77
|
return image
|
80
78
|
|
81
79
|
|
82
|
-
|
80
|
+
@cachier()
def geography_fn(place, buffer) -> Terrain:
    """Build the `Terrain` layers for `place` on a `buffer` x `buffer` grid.

    The function is wrapped in `cachier`, so results are cached per argument
    set.

    Args:
        place: location handed to `get_bbox`.
        buffer: handed to `get_bbox` and also used as the raster side length.

    Returns:
        A `Terrain` whose four arrays have all been cast to int16.
    """

    def rotate(layer):
        # Same quarter-turn rotation is applied to every layer and the basemap.
        return jnp.rot90(layer, 3)

    bbox = get_bbox(place, buffer)
    features = gpd.GeoDataFrame(ox.features_from_bbox(bbox=bbox, tags=tags))
    clip_region = box(bbox.west, bbox.south, bbox.east, bbox.north)
    features = features.clip(clip_region).to_crs("EPSG:3857")

    raster = raster_fn(features, shape=(buffer, buffer))
    basemap = rotate(basemap_fn(bbox, features))

    # NOTE(review): layer indices are presumably 0 building, 1 water,
    # 2 highway, 3 forest, 4 garden - confirm against raster_fn.
    neighbourhood = jnp.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]])
    # Assuming binary layers: water cells inside the 3x3 neighbourhood of a
    # cell where layers 1 and 2 overlap are dropped from the water layer.
    overlap = convolve(raster[1] * raster[2], neighbourhood, mode="same")
    water = raster[1] - overlap > 0
    forest = jnp.logical_or(raster[3], raster[4])

    layers = Terrain(
        building=rotate(raster[0]),
        water=rotate(water),
        forest=rotate(forest),
        basemap=basemap,
    )
    return tree.map(lambda arr: arr.astype(jnp.int16), layers)
|
104
98
|
|
105
99
|
|
@@ -128,25 +122,25 @@ def feature_fn(t, feature, gdf, shape):
|
|
128
122
|
|
129
123
|
|
130
124
|
# %%
|
131
|
-
def normalize(x):
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
def get_bridges(gdf):
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
125
|
+
# def normalize(x):
|
126
|
+
# return (np.array(x) - m) / (M - m)
|
127
|
+
|
128
|
+
|
129
|
+
# def get_bridges(gdf):
|
130
|
+
# xmin, ymin, xmax, ymax = gdf.total_bounds
|
131
|
+
# m = np.array([xmin, ymin])
|
132
|
+
# M = np.array([xmax, ymax])
|
133
|
+
|
134
|
+
# bridges = {}
|
135
|
+
# for idx, bridge in gdf[gdf["bridge"] == "yes"].iterrows():
|
136
|
+
# if type(bridge["name"]) == str:
|
137
|
+
# bridges[idx[1]] = {
|
138
|
+
# "name": bridge["name"],
|
139
|
+
# "coords": normalize(
|
140
|
+
# [bridge.geometry.centroid.x, bridge.geometry.centroid.y]
|
141
|
+
# ),
|
142
|
+
# }
|
143
|
+
# return bridges
|
150
144
|
|
151
145
|
|
152
146
|
"""
|
parabellum/model.py
ADDED
parabellum/ppo.py
ADDED
@@ -0,0 +1 @@
|
|
1
|
+
|
parabellum/terrain_db.py
CHANGED
@@ -1,22 +1,22 @@
|
|
1
1
|
# %%
|
2
2
|
import numpy as np
|
3
3
|
import jax.numpy as jnp
|
4
|
-
from parabellum import
|
4
|
+
from parabellum.types import Terrain
|
5
5
|
|
6
6
|
|
7
7
|
# %%
|
8
8
|
def map_raster_from_line(raster, line, size):
    """Mark the cells of a straight segment in `raster`, in place.

    Args:
        raster: object supporting `raster[x, y] = 1`; it is modified in place.
        line: `(x0, y0, dx, dy)` given as fractions of `size`: the start point
            and the extent along each axis.
        size: side length of the square grid, in cells.

    Returns:
        The same `raster` object, with every sampled in-bounds cell set to 1.
    """
    x0, y0, dx, dy = (int(frac * size) for frac in line)
    steps = int(2**0.5 * size)  # number of samples: the grid diagonal, in cells
    if steps == 0:
        # Degenerate grid (size 0): no cell can be in bounds, and the
        # interpolation below would divide by zero.
        return raster
    for t in range(steps + 1):
        alpha = t / float(steps)
        # An axis with zero extent stays pinned to its start coordinate.
        x = x0 if dx == 0 else int((1 - alpha) * x0 + alpha * (x0 + dx))
        y = y0 if dy == 0 else int((1 - alpha) * y0 + alpha * (y0 + dy))
        if 0 <= x < size and 0 <= y < size:
            raster[x, y] = 1
    return raster
|
22
22
|
|
@@ -24,20 +24,21 @@ def map_raster_from_line(raster, line, size):
|
|
24
24
|
# %%
|
25
25
|
def map_raster_from_rect(raster, rect, size):
    """Fill an axis-aligned rectangle of `raster` with ones, in place.

    Args:
        raster: array supporting 2-D slice assignment; it is modified in place.
        rect: `(x0, y0, dx, dy)` given as fractions of `size`: the start corner
            and the extent along each axis.
        size: scale factor that turns the fractions into cell indices.

    Returns:
        The same `raster` object.
    """
    start0, start1, extent0, extent1 = (int(frac * size) for frac in rect)
    span0 = slice(start0, start0 + extent0)
    span1 = slice(start1, start1 + extent1)
    raster[span0, span1] = 1
    return raster
|
33
33
|
|
34
34
|
|
35
35
|
# %%
|
36
|
-
building_color = jnp.array([201,199,198, 255])
|
36
|
+
# Fill colours that make_terrain uses to paint the basemap, one 4-value entry
# per layer plus the background.
# NOTE(review): presumably RGBA in 0-255 with full alpha - confirm against the
# code that renders the basemap.
building_color = jnp.array([201, 199, 198, 255])
water_color = jnp.array([193, 237, 254, 255])
forest_color = jnp.array([197, 214, 185, 255])
empty_color = jnp.array([255, 255, 255, 255])
|
40
40
|
|
41
|
+
|
41
42
|
def make_terrain(terrain_args, size):
|
42
43
|
args = {}
|
43
44
|
for key, config in terrain_args.items():
|
@@ -49,44 +50,75 @@ def make_terrain(terrain_args, size):
|
|
49
50
|
elif "rect" in elem:
|
50
51
|
raster = map_raster_from_rect(raster, elem["rect"], size)
|
51
52
|
args[key] = jnp.array(raster.T)
|
52
|
-
basemap = jnp.where(
|
53
|
-
|
54
|
-
|
53
|
+
basemap = jnp.where(
|
54
|
+
args["building"][:, :, None], jnp.tile(building_color, (size, size, 1)), jnp.tile(empty_color, (size, size, 1))
|
55
|
+
)
|
56
|
+
basemap = jnp.where(args["water"][:, :, None], jnp.tile(water_color, (size, size, 1)), basemap)
|
57
|
+
basemap = jnp.where(args["forest"][:, :, None], jnp.tile(forest_color, (size, size, 1)), basemap)
|
55
58
|
args["basemap"] = basemap
|
56
|
-
return
|
59
|
+
return Terrain(**args)
|
57
60
|
|
58
61
|
|
59
62
|
# %%
|
60
63
|
# Hand-made terrain layouts, keyed by layout name. Every layout has the three
# layer keys "building", "water" and "forest", each holding None, an empty
# list, or a list of shapes.
# A shape is {"line": [x0, y0, dx, dy]} or {"rect": [x0, y0, dx, dy]}; the four
# numbers are multiplied by the grid size in map_raster_from_line /
# map_raster_from_rect, i.e. they are fractions of the map side.
# NOTE(review): whether None and [] are handled identically is decided in
# make_terrain - confirm there.
db = {
    "blank": {"building": None, "water": None, "forest": None},
    "F": {
        "building": [
            {"line": [0.25, 0.33, 0.5, 0]},
            {"line": [0.75, 0.33, 0.0, 0.25]},
            {"line": [0.50, 0.33, 0.0, 0.25]},
        ],
        "water": None,
        "forest": None,
    },
    "stronghold": {
        "building": [
            {"line": [0.2, 0.275, 0.2, 0.0]},
            {"line": [0.2, 0.275, 0.0, 0.2]},
            {"line": [0.4, 0.275, 0.0, 0.2]},
            {"line": [0.2, 0.475, 0.2, 0.0]},
            {"line": [0.2, 0.525, 0.2, 0.0]},
            {"line": [0.2, 0.525, 0.0, 0.2]},
            {"line": [0.4, 0.525, 0.0, 0.2]},
            {"line": [0.2, 0.725, 0.525, 0.0]},
            {"line": [0.75, 0.25, 0.0, 0.2]},
            {"line": [0.75, 0.55, 0.0, 0.19]},
            {"line": [0.6, 0.25, 0.15, 0.0]},
        ],
        "water": None,
        "forest": None,
    },
    "playground": {"building": [{"line": [0.5, 0.5, 0.5, 0.0]}], "water": None, "forest": None},
    "playground2": {
        "building": [],
        "water": [{"rect": [0.0, 0.8, 0.1, 0.1]}, {"rect": [0.2, 0.8, 0.8, 0.1]}],
        "forest": [{"rect": [0.0, 0.0, 1.0, 0.2]}],
    },
    "triangle": {
        "building": [{"line": [0.33, 0.0, 0.0, 1.0]}, {"line": [0.66, 0.0, 0.0, 1.0]}],
        "water": None,
        "forest": None,
    },
    "u_shape": {
        "building": [],
        "water": [{"rect": [0.15, 0.2, 0.1, 0.5]}, {"rect": [0.4, 0.2, 0.1, 0.5]}, {"rect": [0.2, 0.2, 0.25, 0.1]}],
        "forest": [],
    },
    "bridges": {
        "building": [],
        "water": [
            {"rect": [0.475, 0.0, 0.05, 0.1]},
            {"rect": [0.475, 0.15, 0.05, 0.575]},
            {"rect": [0.475, 0.775, 0.05, 1.0]},
            {"rect": [0.0, 0.475, 0.225, 0.05]},
            {"rect": [0.275, 0.475, 0.45, 0.05]},
            {"rect": [0.775, 0.475, 0.23, 0.05]},
        ],
        "forest": [
            {"rect": [0.1, 0.625, 0.275, 0.275]},
            {"rect": [0.725, 0.0, 0.3, 0.275]},
        ],
    },
}
|
91
123
|
|
92
124
|
# %% [raw]
|
@@ -128,7 +160,7 @@ if __name__ == "__main__":
|
|
128
160
|
plt.imshow(jnp.rot90(terrain.basemap))
|
129
161
|
bl = (39.5, 5)
|
130
162
|
tr = (44.5, 10)
|
131
|
-
plt.scatter(bl[0], 49-bl[1])
|
132
|
-
plt.scatter(tr[0], 49-tr[1], marker="+")
|
163
|
+
plt.scatter(bl[0], 49 - bl[1])
|
164
|
+
plt.scatter(tr[0], 49 - tr[1], marker="+")
|
133
165
|
|
134
166
|
# %%
|
parabellum/types.py
ADDED
@@ -0,0 +1,54 @@
|
|
1
|
+
# types.py
|
2
|
+
# parabellum types
|
3
|
+
# by: Noah Syrkis
|
4
|
+
|
5
|
+
# imports
|
6
|
+
from chex import dataclass
|
7
|
+
from jaxtyping import Array, Bool, Float16
|
8
|
+
|
9
|
+
|
10
|
+
# dataclasses
|
11
|
+
@dataclass
class State:
    """Unit arrays that make up one simulation state."""

    # Unit positions; indexed by unit index and cast to integers in env.mask_fn.
    unit_position: Array
    # NOTE(review): presumably one health value per unit - confirm.
    unit_health: Array
    # NOTE(review): presumably one cooldown counter per unit - confirm.
    unit_cooldown: Array
|
16
|
+
|
17
|
+
|
18
|
+
@dataclass
class Obs:
    """Observation record holding id, position, health and cooldown arrays.

    NOTE(review): presumably these describe the units visible to one observer;
    the code that builds `Obs` is not in this file - confirm there.
    """

    unit_id: Array
    unit_pos: Array
    unit_health: Array
    unit_cooldown: Array
|
24
|
+
|
25
|
+
|
26
|
+
@dataclass
class Action:
    """Action record: a 2-vector per entry plus a boolean flag per entry."""

    # float16 array whose last axis has length 2 (any leading batch axes).
    coord: Float16[Array, "... 2"]  # noqa
    # Boolean array with the same leading axes.
    # NOTE(review): presumably selects between two action kinds - confirm
    # which value means which.
    kinds: Bool[Array, "..."]
|
30
|
+
|
31
|
+
|
32
|
+
@dataclass
class Terrain:
    """Raster layers of a map plus the image used as its background.

    Built by `geo.geography_fn` and by `terrain_db.make_terrain`.
    """

    # Building raster; env.slice_fn cuts its line-of-sight window from this layer.
    building: Array
    # Water raster.
    water: Array
    # Forest raster.
    forest: Array
    # Background image of the map.
    basemap: Array
|
38
|
+
|
39
|
+
|
40
|
+
@dataclass
class Scene:
    """Static description of a scenario: terrain, unit roster and type stats."""

    # Map layers.
    terrain: Terrain
    # Obstacle mask; env.slice_fn indexes it by the absolute offset between two
    # positions and uses its last dimension as the window side.
    mask: Array

    # Integer type id per unit; used to index the unit_type_* arrays below.
    unit_types: Array
    # Integer team id per unit (the env builder writes 0 for allies, 1 for enemies).
    unit_teams: Array

    # Per-type attributes, looked up as unit_type_x[unit_types].
    unit_type_health: Array
    unit_type_damage: Array
    unit_type_reload: Array

    unit_type_reach: Array
    # Compared against pairwise distances in env.mask_fn.
    unit_type_sight: Array
    unit_type_speed: Array
|
@@ -1,39 +1,41 @@
|
|
1
|
-
Metadata-Version: 2.
|
1
|
+
Metadata-Version: 2.3
|
2
2
|
Name: parabellum
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.5.13
|
4
4
|
Summary: Parabellum environment for parallel warfare simulation
|
5
|
-
Home-page: https://github.com/syrkis/parabellum
|
6
|
-
License: MIT
|
7
|
-
Keywords: warfare,simulation,parallel,environment
|
8
5
|
Author: Noah Syrkis
|
9
6
|
Author-email: desk@syrkis.com
|
10
7
|
Requires-Python: >=3.11,<3.12
|
11
|
-
Classifier: License :: OSI Approved :: MIT License
|
12
8
|
Classifier: Programming Language :: Python :: 3
|
13
9
|
Classifier: Programming Language :: Python :: 3.11
|
10
|
+
Requires-Dist: brax (>=0.12.1,<0.13.0)
|
11
|
+
Requires-Dist: cachier (>=3.1.2,<4.0.0)
|
14
12
|
Requires-Dist: cartopy (>=0.23.0,<0.24.0)
|
15
13
|
Requires-Dist: contextily (>=1.6.0,<2.0.0)
|
16
|
-
Requires-Dist:
|
14
|
+
Requires-Dist: distrax (>=0.1.5,<0.2.0)
|
17
15
|
Requires-Dist: einops (>=0.8.0,<0.9.0)
|
16
|
+
Requires-Dist: equinox (>=0.11.11,<0.12.0)
|
17
|
+
Requires-Dist: evosax (>=0.1.6,<0.2.0)
|
18
|
+
Requires-Dist: flashbax (>=0.1.2,<0.2.0)
|
19
|
+
Requires-Dist: flax (>=0.10.4,<0.11.0)
|
18
20
|
Requires-Dist: folium (>=0.17.0,<0.18.0)
|
19
21
|
Requires-Dist: geopy (>=2.4.1,<3.0.0)
|
22
|
+
Requires-Dist: gymnax (>=0.0.8,<0.0.9)
|
20
23
|
Requires-Dist: ipykernel (>=6.29.5,<7.0.0)
|
21
|
-
Requires-Dist: jax (
|
22
|
-
Requires-Dist:
|
24
|
+
Requires-Dist: jax (>=0.5.0,<0.6.0)
|
25
|
+
Requires-Dist: jax-tqdm (>=0.3.1,<0.4.0)
|
23
26
|
Requires-Dist: jaxtyping (>=0.2.33,<0.3.0)
|
24
27
|
Requires-Dist: jupyterlab (>=4.2.2,<5.0.0)
|
25
|
-
Requires-Dist:
|
26
|
-
Requires-Dist: numpy (
|
27
|
-
Requires-Dist:
|
28
|
+
Requires-Dist: navix (>=0.7.0,<0.8.0)
|
29
|
+
Requires-Dist: numpy (>=2.2.3,<3.0.0)
|
30
|
+
Requires-Dist: omegaconf (>=2.3.0,<3.0.0)
|
31
|
+
Requires-Dist: optax (>=0.2.4,<0.3.0)
|
28
32
|
Requires-Dist: osmnx (==2.0.0b0)
|
29
33
|
Requires-Dist: pandas (>=2.2.2,<3.0.0)
|
30
34
|
Requires-Dist: poetry (>=1.8.3,<2.0.0)
|
31
|
-
Requires-Dist: pygame (>=2.5.2,<3.0.0)
|
32
35
|
Requires-Dist: rasterio (>=1.3.10,<2.0.0)
|
33
|
-
Requires-Dist: seaborn (>=0.13.2,<0.14.0)
|
34
36
|
Requires-Dist: stadiamaps (>=3.2.1,<4.0.0)
|
35
37
|
Requires-Dist: tqdm (>=4.66.4,<5.0.0)
|
36
|
-
|
38
|
+
Requires-Dist: wandb (>=0.19.7,<0.20.0)
|
37
39
|
Description-Content-Type: text/markdown
|
38
40
|
|
39
41
|
# Parabellum
|
@@ -0,0 +1,15 @@
|
|
1
|
+
parabellum/__init__.py,sha256=Og0bpKlQtkWCJ1yaQk98LniI1sG3B7GAR7aPMFC-v74,96
|
2
|
+
parabellum/aid.py,sha256=hCp-eDONcloKMNpYni61cE6St_67Lks2ivS5eJjSHFQ,1379
|
3
|
+
parabellum/env.py,sha256=NmDXpjzwun1PsgaHPcIH-i4H8TfkY0Iv1sEy2PKoMnA,5357
|
4
|
+
parabellum/geo.py,sha256=Zv2GCwaYKsfB6CE_cwEHhddU9p9Z5e5lL16iNfETVmw,6181
|
5
|
+
parabellum/gun.py,sha256=nvsJdcZ2Qd6lbPlAgsUiaLhstTi1UdLQ8kOnbCenucY,2618
|
6
|
+
parabellum/model.py,sha256=o40jW2vp3Fwxt1KqykZ-qZXs73-H24nA4iiRoetAobA,117
|
7
|
+
parabellum/pcg.py,sha256=d8KC_lbc4WUUUPaTdPJSx27VMGioys3jSGOWJ-2EahU,968
|
8
|
+
parabellum/ppo.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
9
|
+
parabellum/run.py,sha256=Q53__AxzROZNgfZLVU5LDdcT61UMCkmQ_Q5wWUIrnqo,3473
|
10
|
+
parabellum/terrain_db.py,sha256=lPd56Qe4_xuVFJcKsftE9RvNN9nPr69GzBJZIS5nKrY,5207
|
11
|
+
parabellum/types.py,sha256=WwMzSQ5qnRo6rqKTHIGlWPcg57lx6ENDSK5DFr3R7-s,832
|
12
|
+
parabellum/vis.py,sha256=ABHveJj0fLRWkxOv3LFIXK20QtdGhjskuFLsp7iTFu0,6185
|
13
|
+
parabellum-0.5.13.dist-info/METADATA,sha256=g_tMs-pztJXVRE__E9VLORyZANegyalSoT7ilMz5oRc,2772
|
14
|
+
parabellum-0.5.13.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
|
15
|
+
parabellum-0.5.13.dist-info/RECORD,,
|
parabellum/tps.py
DELETED
@@ -1,17 +0,0 @@
|
|
1
|
-
# %%
|
2
|
-
# tps.py
|
3
|
-
# parabellum types and dataclasses
|
4
|
-
# by: Noah Syrkis
|
5
|
-
|
6
|
-
# %% Imports
|
7
|
-
from chex import dataclass
|
8
|
-
from jaxtyping import Array
|
9
|
-
|
10
|
-
|
11
|
-
# %% Dataclasses
|
12
|
-
@dataclass
|
13
|
-
class Terrain:
|
14
|
-
building: Array
|
15
|
-
water: Array
|
16
|
-
forest: Array
|
17
|
-
basemap: Array
|
@@ -1,13 +0,0 @@
|
|
1
|
-
parabellum/__init__.py,sha256=hIOLir7wgaf_HU4j8uos7PaCrofqPQcr3FcMlBsZyr8,406
|
2
|
-
parabellum/aid.py,sha256=BPabjN4BUq1HRhkwbc9pCNsXSF_ALiG8W8cHWTWeEH4,900
|
3
|
-
parabellum/env.py,sha256=2bAuT-8ewIhDvBpQDqJ15FIbYEJLnb2MwLkMAg0_Ofc,22880
|
4
|
-
parabellum/geo.py,sha256=PJs9UevibuokDVb3oJWNHvYHlMYGCxB5OkNSbDj48vI,6198
|
5
|
-
parabellum/gun.py,sha256=nvsJdcZ2Qd6lbPlAgsUiaLhstTi1UdLQ8kOnbCenucY,2618
|
6
|
-
parabellum/pcg.py,sha256=d8KC_lbc4WUUUPaTdPJSx27VMGioys3jSGOWJ-2EahU,968
|
7
|
-
parabellum/run.py,sha256=Q53__AxzROZNgfZLVU5LDdcT61UMCkmQ_Q5wWUIrnqo,3473
|
8
|
-
parabellum/terrain_db.py,sha256=5lHzbX94lzkb--cETpraXS42G4T4tKekISpTm4yaYEE,4748
|
9
|
-
parabellum/tps.py,sha256=of-RBdelAbNCHQZd1I22RWmZkwUEh6f161mx0X_G2tE,257
|
10
|
-
parabellum/vis.py,sha256=ABHveJj0fLRWkxOv3LFIXK20QtdGhjskuFLsp7iTFu0,6185
|
11
|
-
parabellum-0.3.4.dist-info/METADATA,sha256=Kqdvkav_Z46LX_QtZ7MmWtQFMopn2K3WLTQrN19jSrM,2707
|
12
|
-
parabellum-0.3.4.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
|
13
|
-
parabellum-0.3.4.dist-info/RECORD,,
|