parabellum 0.3.0__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
parabellum/__init__.py CHANGED
@@ -2,14 +2,14 @@ from .env import Environment, Scenario, make_scenario, State
2
2
  from .vis import Visualizer, Skin
3
3
  from .gun import bullet_fn
4
4
  from . import vis
5
- from . import map
5
+ from . import terrain_db
6
6
  from . import env
7
7
  from . import tps
8
8
  # from .run import run
9
9
 
10
10
  __all__ = [
11
11
  "env",
12
- "map",
12
+ "terrain_db",
13
13
  "vis",
14
14
  "tps",
15
15
  "Environment",
parabellum/env.py CHANGED
@@ -7,9 +7,11 @@ from flax.struct import dataclass
7
7
  import chex
8
8
  from jaxmarl.environments.smax.smax_env import SMAX
9
9
 
10
+ from math import ceil
11
+
10
12
  from typing import Tuple, Dict, cast
11
13
  from functools import partial
12
- from parabellum import tps, geo
14
+ from parabellum import tps, geo, terrain_db
13
15
 
14
16
 
15
17
  @dataclass
@@ -17,7 +19,7 @@ class Scenario:
17
19
  """Parabellum scenario"""
18
20
 
19
21
  place: str
20
- terrain_raster: tps.Terrain
22
+ terrain: tps.Terrain
21
23
  unit_starting_sectors: jnp.ndarray # must be of size (num_units, 4) where sectors[i] = (x, y, width, height) of the ith unit's spawning sector (in % of the real map)
22
24
  unit_types: chex.Array
23
25
  num_allies: int
@@ -42,7 +44,6 @@ class State:
42
44
  terminal: bool
43
45
 
44
46
 
45
-
46
47
  def make_scenario(
47
48
  place,
48
49
  size,
@@ -52,19 +53,27 @@ def make_scenario(
52
53
  enemies_type,
53
54
  n_enemies,
54
55
  ):
55
- terrain = geo.geography_fn(place, size)
56
+ if place in terrain_db.db:
57
+ terrain = terrain_db.make_terrain(terrain_db.db[place], size)
58
+ else:
59
+ terrain = geo.geography_fn(place, size)
56
60
  if type(unit_starting_sectors) == list:
57
- default_sector = [0, 0, size, size] # Noah feel confident that this is right. This means 50% chance. Sorry timothee if you end up here later. my bad bro.
58
- correct_unit_starting_sectors = []
59
- for i in range(n_allies+n_enemies):
60
- selected_sector = None
61
- for unit_ids, sector in unit_starting_sectors:
62
- if i in unit_ids:
63
- selected_sector = sector
64
- if selected_sector is None:
65
- selected_sector = default_sector
66
- correct_unit_starting_sectors.append(selected_sector)
67
- unit_starting_sectors = correct_unit_starting_sectors
61
+ default_sector = [
62
+ 0,
63
+ 0,
64
+ size,
65
+ size,
66
+ ] # Noah feel confident that this is right. This means 50% chance. Sorry timothee if you end up here later. my bad bro.
67
+ correct_unit_starting_sectors = []
68
+ for i in range(n_allies + n_enemies):
69
+ selected_sector = None
70
+ for unit_ids, sector in unit_starting_sectors:
71
+ if i in unit_ids:
72
+ selected_sector = sector
73
+ if selected_sector is None:
74
+ selected_sector = default_sector
75
+ correct_unit_starting_sectors.append(selected_sector)
76
+ unit_starting_sectors = correct_unit_starting_sectors
68
77
  if type(allies_type) == int:
69
78
  allies = [allies_type] * n_allies
70
79
  else:
@@ -78,7 +87,30 @@ def make_scenario(
78
87
  enemies = enemies_type
79
88
  unit_types = jnp.array(allies + enemies, dtype=jnp.uint8)
80
89
  return Scenario(
81
- place, terrain, unit_starting_sectors, unit_types, n_allies, n_enemies # type: ignore
90
+ place,
91
+ terrain,
92
+ unit_starting_sectors, # type: ignore
93
+ unit_types,
94
+ n_allies,
95
+ n_enemies,
96
+ )
97
+
98
+
99
+ def scenario_fn(place, size):
100
+ # scenario function for Noah, cos the one above is confusing
101
+ terrain = geo.geography_fn(place, size)
102
+ num_allies = 10
103
+ num_enemies = 10
104
+ unit_types = jnp.array([0] * num_allies + [1] * num_enemies, dtype=jnp.uint8)
105
+ # start units in default sectors
106
+ unit_starting_sectors = jnp.array([[0, 0, 1, 1]] * (num_allies + num_enemies))
107
+ return Scenario(
108
+ place=place,
109
+ terrain=terrain,
110
+ unit_starting_sectors=unit_starting_sectors,
111
+ unit_types=unit_types,
112
+ num_allies=num_allies,
113
+ num_enemies=num_enemies,
82
114
  )
83
115
 
84
116
 
@@ -87,7 +119,7 @@ def spawn_fn(rng: jnp.ndarray, units_spawning_sectors):
87
119
  spawn_positions = []
88
120
  for sector in units_spawning_sectors:
89
121
  rng, key_start, key_noise = random.split(rng, 3)
90
- noise = random.uniform(key_noise, (2,)) * 0.5
122
+ noise = 0.25 + random.uniform(key_noise, (2,)) * 0.5
91
123
  idx = random.choice(key_start, sector[0].shape[0])
92
124
  coord = jnp.array([sector[0][idx], sector[1][idx]])
93
125
  spawn_positions.append(coord + noise)
@@ -101,47 +133,65 @@ def sectors_fn(sectors: jnp.ndarray, invalid_spawn_areas: jnp.ndarray):
101
133
  width, height = invalid_spawn_areas.shape
102
134
  spawning_sectors = []
103
135
  for sector in sectors:
104
- coordx, coordy = jnp.array(sector[0] * width, dtype=jnp.int32), jnp.array(sector[1] * height, dtype=jnp.int32)
105
- sector = (invalid_spawn_areas[coordy : coordy + int(sector[3] * height), coordx : coordx + int(sector[2] * width)] == 0)
106
- valid = jnp.nonzero(sector.T)
136
+ coordx, coordy = (
137
+ jnp.array(sector[0] * width, dtype=jnp.int32),
138
+ jnp.array(sector[1] * height, dtype=jnp.int32),
139
+ )
140
+ sector = (
141
+ invalid_spawn_areas[
142
+ coordx : coordx + ceil(sector[2] * width),
143
+ coordy : coordy + ceil(sector[3] * height),
144
+ ]
145
+ == 0
146
+ )
147
+ valid = jnp.nonzero(sector)
107
148
  if valid[0].shape[0] == 0:
108
149
  raise ValueError(f"Sector {sector} only contains invalid spawn areas.")
109
- spawning_sectors.append(jnp.array(valid) + jnp.array([coordx, coordy]).reshape((2, -1) ))
150
+ spawning_sectors.append(
151
+ jnp.array(valid) + jnp.array([coordx, coordy]).reshape((2, -1))
152
+ )
110
153
  return spawning_sectors
111
154
 
112
155
 
113
156
  class Environment(SMAX):
114
-
115
157
  def __init__(self, scenario: Scenario, **kwargs):
116
- map_height, map_width = scenario.terrain_raster.building.shape
158
+ map_height, map_width = scenario.terrain.building.shape
117
159
  args = dict(scenario=scenario, map_height=map_height, map_width=map_width)
118
160
  if "unit_type_pushable" in kwargs:
119
161
  self.unit_type_pushable = kwargs["unit_type_pushable"]
120
162
  del kwargs["unit_type_pushable"]
121
163
  else:
122
- self.unit_type_pushable = jnp.array([1,1,0,0,0,1])
164
+ self.unit_type_pushable = jnp.array([1, 1, 0, 0, 0, 1])
123
165
  if "reset_when_done" in kwargs:
124
166
  self.reset_when_done = kwargs["reset_when_done"]
125
167
  del kwargs["reset_when_done"]
126
168
  else:
127
169
  self.reset_when_done = True
128
170
  super(Environment, self).__init__(**args, walls_cause_death=False, **kwargs)
129
- self.terrain_raster = scenario.terrain_raster
171
+ self.terrain = scenario.terrain
130
172
  self.unit_starting_sectors = scenario.unit_starting_sectors
131
173
  # self.unit_type_names = ["tinker", "tailor", "soldier", "spy"]
132
174
  # self.unit_type_health = jnp.array([100, 100, 100, 100], dtype=jnp.float32)
133
175
  # self.unit_type_damage = jnp.array([10, 10, 10, 10], dtype=jnp.float32)
134
176
  self.scenario = scenario
135
- self.unit_type_velocities = jnp.array([3.15, 2.25, 4.13, 3.15, 4.13, 3.15])/2.5 if "unit_type_velocities" not in kwargs else kwargs["unit_type_velocities"]
177
+ self.unit_type_velocities = (
178
+ jnp.array([3.15, 2.25, 4.13, 3.15, 4.13, 3.15]) / 2.5
179
+ if "unit_type_velocities" not in kwargs
180
+ else kwargs["unit_type_velocities"]
181
+ )
136
182
  self.unit_type_attack_blasts = jnp.zeros((3,), dtype=jnp.float32) # TODO: add
137
183
  self.max_steps = 200
138
184
  self._push_units_away = lambda state, firmness=1: state # overwrite push units
139
- self.spawning_sectors = sectors_fn(self.unit_starting_sectors, scenario.terrain_raster.building + scenario.terrain_raster.water)
140
- self.resolution = self.terrain_raster.building.shape[0] + self.terrain_raster.building.shape[1]
141
- self.t = jnp.tile(jnp.linspace(0, 1, self.resolution), (2, self.resolution))
142
-
185
+ self.spawning_sectors = sectors_fn(
186
+ self.unit_starting_sectors,
187
+ scenario.terrain.building + scenario.terrain.water,
188
+ )
189
+ self.resolution = (
190
+ jnp.array(jnp.max(self.unit_type_sight_ranges), dtype=jnp.int32) * 2
191
+ )
192
+ self.t = jnp.tile(jnp.linspace(0, 1, self.resolution), (2, 1))
143
193
 
144
- def reset(self, rng: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]: # type: ignore
194
+ def reset(self, rng: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]: # type: ignore
145
195
  """Environment-specific reset."""
146
196
  unit_positions = spawn_fn(rng, self.spawning_sectors)
147
197
  unit_teams = jnp.zeros((self.num_agents,))
@@ -161,24 +211,25 @@ class Environment(SMAX):
161
211
  time=0,
162
212
  terminal=False,
163
213
  unit_weapon_cooldowns=unit_weapon_cooldowns,
164
- # terrain=self.terrain_raster,
214
+ # terrain=self.terrain,
165
215
  )
166
216
  state = self._push_units_away(state) # type: ignore could be slow
167
217
  obs = self.get_obs(state)
218
+ # remove world_state from obs
168
219
  world_state = self.get_world_state(state)
169
- # obs["world_state"] = jax.lax.stop_gradient(world_state)
220
+ obs["world_state"] = jax.lax.stop_gradient(world_state)
170
221
  return obs, state
171
222
 
172
- def step_env(self, rng, state: State, action: Array): # type: ignore
173
- obs, state, rewards, dones, infos = super().step_env(rng, state, action)
223
+ # def step_env(self, rng, state: State, action: Array): # type: ignore
224
+ # obs, state, rewards, dones, infos = super().step_env(rng, state, action)
174
225
  # delete world_state from obs
175
- obs.pop("world_state")
176
- if not self.reset_when_done:
177
- for key in dones.keys():
178
- dones[key] = False
179
- return obs, state, rewards, dones, infos
226
+ # obs.pop("world_state")
227
+ # if not self.reset_when_done:
228
+ # for key in dones.keys():
229
+ # dones[key] = False
230
+ # return obs, state, rewards, dones, infos
180
231
 
181
- def get_obs_unit_list(self, state: State) -> Dict[str, chex.Array]: # type: ignore
232
+ def get_obs_unit_list(self, state: State) -> Dict[str, chex.Array]: # type: ignore
182
233
  """Applies observation function to state."""
183
234
 
184
235
  def get_features(i, j):
@@ -188,9 +239,7 @@ class Environment(SMAX):
188
239
  # The observation is such that allies are always first
189
240
  # so for units in the second team we count in reverse.
190
241
  j = jax.lax.cond(
191
- i < self.num_allies,
192
- lambda: j,
193
- lambda: self.num_agents - j - 1,
242
+ i < self.num_allies, lambda: j, lambda: self.num_agents - j - 1
194
243
  )
195
244
  offset = jax.lax.cond(i < self.num_allies, lambda: 1, lambda: -1)
196
245
  j_idx = jax.lax.cond(
@@ -205,8 +254,14 @@ class Environment(SMAX):
205
254
  < self.unit_type_sight_ranges[state.unit_types[i]]
206
255
  )
207
256
  return jax.lax.cond(
208
- visible & state.unit_alive[i] & state.unit_alive[j_idx]
209
- & self.has_line_of_sight(state.unit_positions[j_idx], state.unit_positions[i], self.terrain_raster.building + self.terrain_raster.forest),
257
+ visible
258
+ & state.unit_alive[i]
259
+ & state.unit_alive[j_idx]
260
+ & self.has_line_of_sight(
261
+ state.unit_positions[j_idx],
262
+ state.unit_positions[i],
263
+ self.terrain.building + self.terrain.forest,
264
+ ),
210
265
  lambda: features,
211
266
  lambda: empty_features,
212
267
  )
@@ -240,20 +295,20 @@ class Environment(SMAX):
240
295
  pos
241
296
  + firmness * jnp.sum(delta_matrix * overlap_term[:, :, None], axis=1) / 2
242
297
  )
243
- return jnp.where(self.unit_type_pushable[unit_types][:, None], unit_positions, pos)
244
-
245
- def has_line_of_sight(self, source, target, raster_input): # this is tooooo slow TODO: make it fast
246
- # we could compute this for units in sight only using a switch
247
-
248
- cells = jnp.array(source[:, jnp.newaxis] * self.t + (1-self.t) * target[:, jnp.newaxis], dtype=jnp.int32)
249
-
250
- mask = jnp.zeros(raster_input.shape).at[cells[1, :], cells[0, :]].set(1)
298
+ return jnp.where(
299
+ self.unit_type_pushable[unit_types][:, None], unit_positions, pos
300
+ )
251
301
 
302
+ def has_line_of_sight(self, source, target, raster_input):
303
+ # suppose that the target is in sight_range of source, otherwise the line of sight might miss some cells
304
+ cells = jnp.array(
305
+ source[:, jnp.newaxis] * self.t + (1 - self.t) * target[:, jnp.newaxis],
306
+ dtype=jnp.int32,
307
+ )
308
+ mask = jnp.zeros(raster_input.shape).at[cells[0, :], cells[1, :]].set(1)
252
309
  flag = ~jnp.any(jnp.logical_and(mask, raster_input))
253
-
254
310
  return flag
255
311
 
256
-
257
312
  @partial(jax.jit, static_argnums=(0,)) # replace the _world_step method
258
313
  def _world_step( # modified version of JaxMARL's SMAX _world_step
259
314
  self,
@@ -265,13 +320,12 @@ class Environment(SMAX):
265
320
  pos, new_pos = pos.astype(jnp.int32), new_pos.astype(jnp.int32)
266
321
  minimum = jnp.minimum(pos, new_pos)
267
322
  maximum = jnp.maximum(pos, new_pos)
268
- mask = jnp.where(jnp.arange(mask.shape[0]) >= minimum[0], mask, 0)
269
- mask = jnp.where(jnp.arange(mask.shape[0]) <= maximum[0], mask, 0)
270
- mask = jnp.where(jnp.arange(mask.shape[1]) >= minimum[1], mask.T, 0).T
271
- mask = jnp.where(jnp.arange(mask.shape[1]) <= maximum[1], mask.T, 0).T
323
+ mask = jnp.where(jnp.arange(mask.shape[0]) >= minimum[0], mask.T, 0).T
324
+ mask = jnp.where(jnp.arange(mask.shape[0]) <= maximum[0], mask.T, 0).T
325
+ mask = jnp.where(jnp.arange(mask.shape[1]) >= minimum[1], mask, 0)
326
+ mask = jnp.where(jnp.arange(mask.shape[1]) <= maximum[1], mask, 0)
272
327
  return jnp.any(mask)
273
328
 
274
-
275
329
  def update_position(idx, vec):
276
330
  # Compute the movements slightly strangely.
277
331
  # The velocities below are for diagonal directions
@@ -287,13 +341,17 @@ class Environment(SMAX):
287
341
  )
288
342
  # avoid going out of bounds
289
343
  new_pos = jnp.maximum(
290
- jnp.minimum(new_pos, jnp.array([self.map_width-1, self.map_height-1])),
344
+ jnp.minimum(
345
+ new_pos, jnp.array([self.map_width - 1, self.map_height - 1])
346
+ ),
291
347
  jnp.zeros((2,)),
292
348
  )
293
349
 
294
350
  #######################################################################
295
351
  ############################################ avoid going into obstacles
296
- clash = raster_crossing(pos, new_pos, self.terrain_raster.building + self.terrain_raster.water)
352
+ clash = raster_crossing(
353
+ pos, new_pos, self.terrain.building + self.terrain.water
354
+ )
297
355
  new_pos = jnp.where(clash, pos, new_pos)
298
356
 
299
357
  #######################################################################
@@ -331,14 +389,11 @@ class Environment(SMAX):
331
389
  attacked_idx = jax.lax.select(
332
390
  action < self.num_movement_actions, idx, attacked_idx
333
391
  )
334
-
392
+ distance = jnp.linalg.norm(
393
+ state.unit_positions[idx] - state.unit_positions[attacked_idx]
394
+ )
335
395
  attack_valid = (
336
- (
337
- jnp.linalg.norm(
338
- state.unit_positions[idx] - state.unit_positions[attacked_idx]
339
- )
340
- < self.unit_type_attack_ranges[state.unit_types[idx]]
341
- )
396
+ (distance <= self.unit_type_attack_ranges[state.unit_types[idx]])
342
397
  & state.unit_alive[idx]
343
398
  & state.unit_alive[attacked_idx]
344
399
  )
@@ -349,21 +404,28 @@ class Environment(SMAX):
349
404
  -self.unit_type_attacks[state.unit_types[idx]],
350
405
  0.0,
351
406
  )
407
+ health_diff = jnp.where(
408
+ state.unit_types[idx] == 1,
409
+ health_diff
410
+ * distance
411
+ / self.unit_type_attack_ranges[state.unit_types[idx]],
412
+ health_diff,
413
+ )
352
414
  # design choice based on the pysc2 randomness details.
353
415
  # See https://github.com/deepmind/pysc2/blob/master/docs/environment.md#determinism-and-randomness
354
416
 
355
417
  #########################################################
356
418
  ############################### Add bystander health diff
357
419
 
358
- bystander_idxs = bystander_fn(attacked_idx) # TODO: use
359
- bystander_valid = (
360
- jnp.where(attack_valid, bystander_idxs, jnp.zeros((self.num_agents,)))
361
- .astype(jnp.bool_) # type: ignore
362
- .astype(jnp.float32)
363
- )
364
- bystander_health_diff = (
365
- bystander_valid * -self.unit_type_attacks[state.unit_types[idx]]
366
- )
420
+ # bystander_idxs = bystander_fn(attacked_idx) # TODO: use
421
+ # bystander_valid = (
422
+ # jnp.where(attack_valid, bystander_idxs, jnp.zeros((self.num_agents,)))
423
+ # .astype(jnp.bool_) # type: ignore
424
+ # .astype(jnp.float32)
425
+ # )
426
+ # bystander_health_diff = (
427
+ # bystander_valid * -self.unit_type_attacks[state.unit_types[idx]]
428
+ # )
367
429
 
368
430
  #########################################################
369
431
  #########################################################
@@ -387,26 +449,28 @@ class Environment(SMAX):
387
449
  health_diff,
388
450
  attacked_idx,
389
451
  cooldown_diff,
390
- (bystander_health_diff, bystander_idxs),
452
+ # (bystander_health_diff, bystander_idxs),
391
453
  )
392
454
 
393
455
  def perform_agent_action(idx, action, key):
394
456
  movement_action, attack_action = action
395
457
  new_pos = update_position(idx, movement_action)
396
- health_diff, attacked_idxes, cooldown_diff, (bystander) = (
397
- update_agent_health(idx, attack_action, key)
458
+ health_diff, attacked_idxes, cooldown_diff = update_agent_health(
459
+ idx, attack_action, key
398
460
  )
399
461
 
400
- return new_pos, (health_diff, attacked_idxes), cooldown_diff, bystander
462
+ return new_pos, (health_diff, attacked_idxes), cooldown_diff
401
463
 
402
464
  keys = jax.random.split(key, num=self.num_agents)
403
- pos, (health_diff, attacked_idxes), cooldown_diff, bystander = jax.vmap(
465
+ pos, (health_diff, attacked_idxes), cooldown_diff = jax.vmap(
404
466
  perform_agent_action
405
467
  )(jnp.arange(self.num_agents), actions, keys)
406
468
 
407
469
  # units push each other
408
470
  new_pos = self._our_push_units_away(pos, state.unit_types)
409
- clash = jax.vmap(raster_crossing, in_axes=(0, 0, None))(pos, new_pos, self.terrain_raster.building + self.terrain_raster.water)
471
+ clash = jax.vmap(raster_crossing, in_axes=(0, 0, None))(
472
+ pos, new_pos, self.terrain.building + self.terrain.water
473
+ )
410
474
  pos = jax.vmap(jnp.where)(clash, pos, new_pos)
411
475
  # avoid going out of bounds
412
476
  pos = jnp.maximum(
@@ -441,8 +505,8 @@ class Environment(SMAX):
441
505
  #########################################################
442
506
  ############################ subtracting bystander health
443
507
 
444
- _, bystander_health_diff = bystander
445
- unit_health -= bystander_health_diff.sum(axis=0) # might be axis=1
508
+ # _, bystander_health_diff = bystander
509
+ # unit_health -= bystander_health_diff.sum(axis=0) # might be axis=1
446
510
 
447
511
  #########################################################
448
512
  #########################################################
@@ -456,14 +520,23 @@ class Environment(SMAX):
456
520
  )
457
521
  return state
458
522
 
523
+
459
524
  if __name__ == "__main__":
460
525
  n_envs = 4
461
526
 
462
-
463
527
  n_allies = 10
464
- scenario_kwargs = {"allies_type": 0, "n_allies": n_allies, "enemies_type": 0, "n_enemies": n_allies,
465
- "place": "Vesterbro, Copenhagen, Denmark", "size": 256, "unit_starting_sectors":
466
- [([i for i in range(n_allies)], [0.,0.45,0.1,0.1]), ([n_allies+i for i in range(n_allies)], [0.8,0.5,0.1,0.1])]}
528
+ scenario_kwargs = {
529
+ "allies_type": 0,
530
+ "n_allies": n_allies,
531
+ "enemies_type": 0,
532
+ "n_enemies": n_allies,
533
+ "place": "Vesterbro, Copenhagen, Denmark",
534
+ "size": 100,
535
+ "unit_starting_sectors": [
536
+ ([i for i in range(n_allies)], [0.0, 0.45, 0.1, 0.1]),
537
+ ([n_allies + i for i in range(n_allies)], [0.8, 0.5, 0.1, 0.1]),
538
+ ],
539
+ }
467
540
  scenario = make_scenario(**scenario_kwargs)
468
541
  env = Environment(scenario)
469
542
  rng, reset_rng = random.split(random.PRNGKey(0))
@@ -471,24 +544,18 @@ if __name__ == "__main__":
471
544
  obs, state = vmap(env.reset)(reset_key)
472
545
  state_seq = []
473
546
 
474
-
475
- from tqdm import tqdm
476
547
  import time
548
+
477
549
  step = vmap(jit(env.step))
478
550
  tic = time.time()
479
- for i in tqdm(range(10)):
551
+ for i in range(10):
480
552
  rng, act_rng, step_rng = random.split(rng, 3)
481
553
  act_key = random.split(act_rng, (len(env.agents), n_envs))
482
- print(tic - time.time())
483
554
  act = {
484
555
  a: vmap(env.action_space(a).sample)(act_key[i])
485
556
  for i, a in enumerate(env.agents)
486
557
  }
487
- print(tic - time.time())
488
558
  step_key = random.split(step_rng, n_envs)
489
- print(tic - time.time())
490
559
  state_seq.append((step_key, state, act))
491
- print(tic - time.time())
492
560
  obs, state, reward, done, infos = step(step_key, state, act)
493
- print(tic - time.time())
494
561
  tic = time.time()
parabellum/geo.py CHANGED
@@ -21,6 +21,7 @@ from typing import Tuple
21
21
  import matplotlib.pyplot as plt
22
22
  import seaborn as sns
23
23
  import os
24
+ from jax.scipy.signal import convolve
24
25
 
25
26
  # %% Types
26
27
  Coords = Tuple[float, float]
@@ -31,7 +32,20 @@ provider = cx.providers.Stadia.StamenTerrain( # type: ignore
31
32
  api_key="86d0d32b-d2fe-49af-8db8-f7751f58e83f"
32
33
  )
33
34
  provider["url"] = provider["url"] + "?api_key={api_key}"
34
- tags = {"building": True, "water": True, "landuse": "forest"} # "road": True}
35
+ tags = {
36
+ "building": True,
37
+ "water": True,
38
+ "highway": True,
39
+ "landuse": [
40
+ "grass",
41
+ "forest",
42
+ "flowerbed",
43
+ "greenfield",
44
+ "village_green",
45
+ "recreation_ground",
46
+ ],
47
+ "leisure": "garden",
48
+ } # "road": True}
35
49
 
36
50
 
37
51
  # %% Coordinate function
@@ -54,34 +68,47 @@ def get_bbox(place: str, buffer) -> BBox:
54
68
  def basemap_fn(bbox: BBox, gdf) -> Array:
55
69
  fig, ax = plt.subplots(figsize=(20, 20), subplot_kw={"projection": ccrs.Mercator()})
56
70
  gdf.plot(ax=ax, color="black", alpha=0, edgecolor="black") # type: ignore
57
- cx.add_basemap(ax, crs=gdf.crs, source=provider, zoom="auto") # type: ignore
71
+ cx.add_basemap(ax, crs=gdf.crs, source=provider, zoom="auto") # type: ignore
58
72
  bbox = gdf.total_bounds
59
73
  ax.set_extent([bbox[0], bbox[2], bbox[1], bbox[3]], crs=ccrs.Mercator()) # type: ignore
60
74
  plt.axis("off")
61
- plt.tight_layout()
75
+ plt.tight_layout(pad=0)
62
76
  fig.canvas.draw()
63
77
  image = jnp.array(fig.canvas.renderer._renderer) # type: ignore
64
78
  plt.close(fig)
65
79
  return image
66
80
 
67
81
 
68
- def geography_fn(place, buffer):
82
+ def geography_fn(place, buffer=400):
69
83
  bbox = get_bbox(place, buffer)
70
84
  map_data = ox.features_from_bbox(bbox=bbox, tags=tags)
71
85
  gdf = gpd.GeoDataFrame(map_data)
72
- gdf = gdf.clip(box(bbox.west, bbox.south, bbox.east, bbox.north)).to_crs("EPSG:3857")
86
+ gdf = gdf.clip(box(bbox.west, bbox.south, bbox.east, bbox.north)).to_crs(
87
+ "EPSG:3857"
88
+ )
73
89
  raster = raster_fn(gdf, shape=(buffer, buffer))
74
- basemap = basemap_fn(bbox, gdf)
75
- terrain = tps.Terrain(building=raster[0], water=raster[1], forest=raster[2], basemap=basemap)
90
+ basemap = jnp.rot90(basemap_fn(bbox, gdf), 3)
91
+ # 0: building", 1: "water", 2: "highway", 3: "forest", 4: "garden"
92
+ kernel = jnp.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]])
93
+ trans = lambda x: jnp.rot90(x, 3)
94
+ terrain = tps.Terrain(
95
+ building=trans(raster[0]),
96
+ water=trans(
97
+ raster[1] - convolve(raster[1] * raster[2], kernel, mode="same") > 0
98
+ ),
99
+ forest=trans(jnp.logical_or(raster[3], raster[4])),
100
+ basemap=basemap,
101
+ )
76
102
  return terrain
77
103
 
78
104
 
79
105
  def raster_fn(gdf, shape) -> Array:
80
106
  bbox = gdf.total_bounds
81
107
  t = transform.from_bounds(*bbox, *shape) # type: ignore
82
- raster = jnp.array([feature_fn(t, feature, gdf, shape) for feature in ["building", "water", "landuse"]])
108
+ raster = jnp.array([feature_fn(t, feature, gdf, shape) for feature in tags])
83
109
  return raster
84
110
 
111
+
85
112
  def feature_fn(t, feature, gdf, shape):
86
113
  if feature not in gdf.columns:
87
114
  return jnp.zeros(shape)
@@ -89,12 +116,15 @@ def feature_fn(t, feature, gdf, shape):
89
116
  raster = features.rasterize(gdf.geometry, out_shape=shape, transform=t, fill=0) # type: ignore
90
117
  return raster
91
118
 
92
- place = "Thun, Switzerland"
93
- terrain = geography_fn(place, 800)
119
+
94
120
  # %%
95
- fig, axes = plt.subplots(1, 5, figsize=(20, 20))
96
- axes[0].imshow(terrain.building, cmap="gray")
97
- axes[1].imshow(terrain.water, cmap="gray")
98
- axes[2].imshow(terrain.forest, cmap="gray")
99
- axes[3].imshow(terrain.building + terrain.water + terrain.forest)
100
- axes[4].imshow(terrain.basemap)
121
+ if __name__ == "__main__":
122
+ place = "Thun, Switzerland"
123
+ terrain = geography_fn(place, 300)
124
+
125
+ fig, axes = plt.subplots(1, 5, figsize=(20, 20))
126
+ axes[0].imshow(terrain.building, cmap="gray")
127
+ axes[1].imshow(terrain.water, cmap="gray")
128
+ axes[2].imshow(terrain.forest, cmap="gray")
129
+ axes[3].imshow(terrain.building + terrain.water + terrain.forest)
130
+ axes[4].imshow(terrain.basemap)
@@ -0,0 +1,117 @@
1
+ # %%
2
+ import numpy as np
3
+ import jax.numpy as jnp
4
+ from parabellum import tps
5
+
6
+
7
+ # %%
8
+ def map_raster_from_line(raster, line, size):
9
+ x0, y0, dx, dy = line
10
+ x0 = int(x0*size)
11
+ y0 = int(y0*size)
12
+ dx = int(dx*size)
13
+ dy = int(dy*size)
14
+ max_T = int(2**0.5 * size)
15
+ for t in range(max_T+1):
16
+ alpha = t/float(max_T)
17
+ x = x0 if dx == 0 else int((1 - alpha) * x0 + alpha * (x0+dx))
18
+ y = y0 if dy == 0 else int((1 - alpha) * y0 + alpha * (y0+dy))
19
+ if 0<=x<size and 0<=y<size:
20
+ raster[x, y] = 1
21
+ return raster
22
+
23
+
24
+ # %%
25
+ def map_raster_from_rect(raster, rect, size):
26
+ x0, y0, dx, dy = rect
27
+ x0 = int(x0*size)
28
+ y0 = int(y0*size)
29
+ dx = int(dx*size)
30
+ dy = int(dy*size)
31
+ raster[x0:x0+dx, y0:y0+dy] = 1
32
+ return raster
33
+
34
+
35
+ # %%
36
+ building_color = jnp.array([201,199,198, 255])
37
+ water_color = jnp.array([193, 237, 254, 255])
38
+ forest_color = jnp.array([197,214,185, 255])
39
+ empty_color = jnp.array([255, 255, 255, 255])
40
+
41
+ def make_terrain(terrain_args, size):
42
+ args = {}
43
+ for key, config in terrain_args.items():
44
+ raster = np.zeros((size, size))
45
+ if config is not None:
46
+ for elem in config:
47
+ if "line" in elem:
48
+ raster = map_raster_from_line(raster, elem["line"], size)
49
+ elif "rect" in elem:
50
+ raster = map_raster_from_rect(raster, elem["rect"], size)
51
+ args[key] = jnp.array(raster.T)
52
+ basemap = jnp.where(args["building"][:,:,None], jnp.tile(building_color, (size, size, 1)), jnp.tile(empty_color, (size,size, 1)))
53
+ basemap = jnp.where(args["water"][:,:,None], jnp.tile(water_color, (size, size, 1)), basemap)
54
+ basemap = jnp.where(args["forest"][:,:,None], jnp.tile(forest_color, (size, size, 1)), basemap)
55
+ args["basemap"] = basemap
56
+ return tps.Terrain(**args)
57
+
58
+
59
+ # %%
60
+ db = {
61
+ "blank": {'building': None, 'water': None, 'forest': None},
62
+ "F": {'building': [{"line": [0.25, 0.33, 0.5, 0]}, {"line":[0.75, 0.33, 0., 0.25]}, {"line":[0.50, 0.33, 0., 0.25]}], 'water': None, 'forest': None},
63
+ "stronghold": {'building': [
64
+ {"line":[0.2, 0.275, 0.2, 0.]}, {"line":[0.2, 0.275, 0.0, 0.2]},
65
+ {"line":[0.4, 0.275, 0.0, 0.2]}, {"line":[0.2, 0.475, 0.2, 0.]},
66
+
67
+ {"line":[0.2, 0.525, 0.2, 0.]}, {"line": [0.2, 0.525, 0.0, 0.2]},
68
+ {"line":[0.4, 0.525, 0.0, 0.2]}, {"line": [0.2, 0.725, 0.525, 0.]},
69
+
70
+ {"line":[0.75, 0.25, 0., 0.2]}, {"line":[0.75, 0.55, 0., 0.19]},
71
+ {"line":[0.6, 0.25, 0.15, 0.]}], 'water': None, 'forest': None},
72
+ "playground": {'building': [{"line":[0.5, 0.5, 0.5, 0.]}], 'water': None, 'forest': None},
73
+ "water_park": {
74
+ 'building': [{"line":[0.5, 0.5, 0.5, 0.]}],
75
+ "water": [{"rect":[0., 0.8, 0.1, 0.05]}, {"rect": [0.2, 0.8, 0.8, 0.05]}],
76
+ "forest": [{"rect": [0., 0., 1., 0.2]}]
77
+ },
78
+ "triangle": {'building': [{"line": [0.33, 0., 0., 1.]}, {"line": [0.66, 0., 0., 1.]}], 'water': None, 'forest': None},
79
+ "u_shape": {
80
+ 'building': [],
81
+ "water": [{"rect": [0.15, 0.2, 0.1, 0.5]}, {"rect": [0.4, 0.2, 0.1, 0.5]}, {"rect": [0.2, 0.2, 0.25, 0.1]}],
82
+ "forest": []
83
+ },
84
+ }
85
+
86
+ # %% [raw]
87
+ # import matplotlib.pyplot as plt
88
+ # size = 50
89
+ # raster = np.zeros((size, size))
90
+ # rect = [0.2, 0.3, 0.05, 0.4]
91
+ # raster = map_raster_from_rect(raster, rect, size)
92
+ # rect = [0.4, 0.3, 0.05, 0.4]
93
+ # raster = map_raster_from_rect(raster, rect, size)
94
+ # rect = [0.2, 0.3, 0.25, 0.05]
95
+ # raster = map_raster_from_rect(raster, rect, size)
96
+ # rect = [0.2, 0.7, 0.25, 0.05]
97
+ # raster = map_raster_from_rect(raster, rect, size)
98
+ # rect = [0.6, 0.3, 0.4, 0.45]
99
+ # raster = map_raster_from_rect(raster, rect, size)
100
+ # plt.imshow(jnp.rot90(raster))
101
+
102
+ # %% [markdown]
103
+ # # Main
104
+
105
+ # %%
106
+ if __name__ == "__main__":
107
+ import matplotlib.pyplot as plt
108
+
109
+ # %%
110
+ terrain = make_terrain(db["u_shape"], size=50)
111
+
112
+ # %%
113
+ plt.imshow(jnp.rot90(terrain.basemap))
114
+
115
+ # %%
116
+
117
+ # %%
parabellum/tps.py CHANGED
@@ -1,3 +1,4 @@
1
+ # %%
1
2
  # tps.py
2
3
  # parabellum types and dataclasses
3
4
  # by: Noah Syrkis
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: parabellum
3
- Version: 0.3.0
3
+ Version: 0.3.3
4
4
  Summary: Parabellum environment for parallel warfare simulation
5
5
  Home-page: https://github.com/syrkis/parabellum
6
6
  License: MIT
@@ -38,7 +38,7 @@ Description-Content-Type: text/markdown
38
38
 
39
39
  # Parabellum
40
40
 
41
- Ultra-scalable JaxMARL based warfare simulation engine developed with Armasuisse funding.
41
+ Ultra-scalable JaxMARL based warfare simulation engine.
42
42
 
43
43
  [![Documentation Status](https://readthedocs.org/projects/parabellum/badge/?version=latest)](https://parabellum.readthedocs.io/en/latest/?badge=latest)
44
44
 
@@ -95,3 +95,4 @@ Full documentation: [parabellum.readthedocs.io](https://parabellum.readthedocs.i
95
95
  ## License
96
96
 
97
97
  MIT
98
+
@@ -0,0 +1,13 @@
1
+ parabellum/__init__.py,sha256=hIOLir7wgaf_HU4j8uos7PaCrofqPQcr3FcMlBsZyr8,406
2
+ parabellum/aid.py,sha256=BPabjN4BUq1HRhkwbc9pCNsXSF_ALiG8W8cHWTWeEH4,900
3
+ parabellum/env.py,sha256=yf5pSmTuFEHS4J4zKTprNTVsOni6bABP9IfWTMrO0OU,22581
4
+ parabellum/geo.py,sha256=PwEwspOppTPrHIXDZB_nGPTnVFIvDzbh2WtqzVKMUaM,4198
5
+ parabellum/gun.py,sha256=nvsJdcZ2Qd6lbPlAgsUiaLhstTi1UdLQ8kOnbCenucY,2618
6
+ parabellum/pcg.py,sha256=d8KC_lbc4WUUUPaTdPJSx27VMGioys3jSGOWJ-2EahU,968
7
+ parabellum/run.py,sha256=Q53__AxzROZNgfZLVU5LDdcT61UMCkmQ_Q5wWUIrnqo,3473
8
+ parabellum/terrain_db.py,sha256=XTKlpLAi3ZwoVw4-KS-Eh15NKsBKP-yt8v6FJGUtwdM,3960
9
+ parabellum/tps.py,sha256=of-RBdelAbNCHQZd1I22RWmZkwUEh6f161mx0X_G2tE,257
10
+ parabellum/vis.py,sha256=ABHveJj0fLRWkxOv3LFIXK20QtdGhjskuFLsp7iTFu0,6185
11
+ parabellum-0.3.3.dist-info/METADATA,sha256=OzXtMvFmkyMwAv4d3X7YFRAhDLuiuRda2ytgsgAXDIA,2707
12
+ parabellum-0.3.3.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
13
+ parabellum-0.3.3.dist-info/RECORD,,
parabellum/map.py DELETED
@@ -1,95 +0,0 @@
1
- # ludens.py
2
- # script for fucking around and finding out
3
- # by: Noah Syrkis
4
-
5
-
6
- # %% Imports
7
- # import parabellum as pb
8
- import matplotlib.pyplot as plt
9
- import osmnx as ox
10
- from geopy.geocoders import Nominatim
11
- import numpy as np
12
- import contextily as cx
13
- import jax.numpy as jnp
14
- import geopandas as gpd
15
- import rasterio
16
- from rasterio import features
17
- from shapely.geometry import Point
18
- from typing import List
19
-
20
- # %% Constants
21
- geolocator = Nominatim(user_agent="parabellum")
22
- source = cx.providers.OpenStreetMap.Mapnik # type: ignore
23
-
24
-
25
- def get_raster(
26
- place: str, meters: int = 1000, tags: List[dict] | dict = {"building": True}
27
- ) -> jnp.ndarray:
28
- # look here for tags https://wiki.openstreetmap.org/wiki/Map_features
29
- def aux(place, tag):
30
- """Rasterize geometry and return as a JAX array."""
31
- place = geolocator.geocode(place) # type: ignore
32
- point = place.latitude, place.longitude # type: ignore # confusing order of lat/lon
33
- geom = ox.features_from_point(point, tags=tag, dist=meters // 2)
34
- gdf = gpd.GeoDataFrame(geom).set_crs("EPSG:4326")
35
- # crop everythin outside of the meters x meters square
36
- gdf = gdf.cx[
37
- place.longitude - meters / 2 : place.longitude + meters / 2,
38
- place.latitude - meters / 2 : place.latitude + meters / 2,
39
- ]
40
-
41
- # bounds should be meters, meters
42
- t = rasterio.transform.from_bounds(*bounds, meters, meters) # type: ignore
43
- raster = features.rasterize(
44
- gdf.geometry, out_shape=(meters, meters), transform=t
45
- )
46
- return jnp.array(raster)
47
-
48
- if isinstance(tags, dict):
49
- return aux(place, tags)
50
- else:
51
- return jnp.stack([aux(place, tag) for tag in tags])
52
-
53
-
54
- def get_basemap(
55
- place: str, size: int = 1000
56
- ) -> np.ndarray: # TODO: image is slightly off from raster. Fix this.
57
- # Create a GeoDataFrame with the center point
58
- place = geolocator.geocode(place) # type: ignore
59
- lon, lat = place.longitude, place.latitude # type: ignore
60
- gdf = gpd.GeoDataFrame(geometry=[Point(lon, lat)], crs="EPSG:4326")
61
- gdf = gdf.to_crs("EPSG:3857")
62
-
63
- # Create a buffer around the center point
64
- # buffer = gdf.buffer(size) # type: ignore
65
- buffer = gdf
66
- bounds = buffer.total_bounds # i think this is wrong, since it ignores empty space
67
- # modify bounds to include empty space
68
- bounds = (bounds[0] - size, bounds[1] - size, bounds[2] + size, bounds[3] + size)
69
-
70
- # Create a figure and axis
71
- dpi = 300
72
- fig, ax = plt.subplots(figsize=(size / dpi, size / dpi), dpi=dpi)
73
- buffer.plot(ax=ax, facecolor="none", edgecolor="red", linewidth=0)
74
-
75
- # Calculate the zoom level for the basemap
76
-
77
- # Add the basemap to the axis
78
- cx.add_basemap(ax, source=source, zoom="auto", attribution=False)
79
-
80
- # Set the x and y limits of the axis
81
- ax.set_xlim(bounds[0], bounds[2])
82
- ax.set_ylim(bounds[1], bounds[3])
83
-
84
- # convert the image (without axis or border) to a numpy array
85
- plt.axis("off")
86
- plt.tight_layout()
87
-
88
- # remove whitespace
89
- fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
90
- fig.canvas.draw()
91
-
92
- image = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) # type: ignore
93
- image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,))
94
- plt.close()
95
- return jnp.array(image) # type: ignore
@@ -1,13 +0,0 @@
1
- parabellum/__init__.py,sha256=vqQbvsTT_zcLThZ7fLoJ6cMAZbEeGIJDFyCkHmovfOY,392
2
- parabellum/aid.py,sha256=BPabjN4BUq1HRhkwbc9pCNsXSF_ALiG8W8cHWTWeEH4,900
3
- parabellum/env.py,sha256=VV3VK7TTkianihqJopRbY0vlRWOquu-VTrc9ep0PSTk,21304
4
- parabellum/geo.py,sha256=xkj6iJqN076tRbaG38Sq7gtwKSNzxI37msRLnpn5JV0,3561
5
- parabellum/gun.py,sha256=nvsJdcZ2Qd6lbPlAgsUiaLhstTi1UdLQ8kOnbCenucY,2618
6
- parabellum/map.py,sha256=9AV0PIqInXcWWojzHshy3X42Nm3ZDq0O1NG-6fQ9Wgw,3345
7
- parabellum/pcg.py,sha256=d8KC_lbc4WUUUPaTdPJSx27VMGioys3jSGOWJ-2EahU,968
8
- parabellum/run.py,sha256=Q53__AxzROZNgfZLVU5LDdcT61UMCkmQ_Q5wWUIrnqo,3473
9
- parabellum/tps.py,sha256=3tVqo42ggE8idZn500C0X2pS9TmYndgBzlAG7Yj2Wz8,252
10
- parabellum/vis.py,sha256=ABHveJj0fLRWkxOv3LFIXK20QtdGhjskuFLsp7iTFu0,6185
11
- parabellum-0.3.0.dist-info/METADATA,sha256=FugXwz25bAPYKlIfqFc7dGVtPupse5zHYapmqBWopE8,2740
12
- parabellum-0.3.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
13
- parabellum-0.3.0.dist-info/RECORD,,