parabellum 0.2.26__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
parabellum/__init__.py CHANGED
@@ -1,18 +1,22 @@
1
- from .env import Environment, Scenario, scenarios, make_scenario, State
1
+ from .env import Environment, Scenario, make_scenario, State
2
2
  from .vis import Visualizer, Skin
3
- from .map import terrain_fn
4
3
  from .gun import bullet_fn
5
- # from .aid import aid
4
+ from . import vis
5
+ from . import map
6
+ from . import env
7
+ from . import tps
6
8
  # from .run import run
7
9
 
8
10
  __all__ = [
11
+ "env",
12
+ "map",
13
+ "vis",
14
+ "tps",
9
15
  "Environment",
10
16
  "Scenario",
11
- "scenarios",
12
17
  "make_scenario",
13
18
  "State",
14
19
  "Visualizer",
15
20
  "Skin",
16
- "terrain_fn",
17
21
  "bullet_fn",
18
22
  ]
parabellum/aid.py CHANGED
@@ -3,3 +3,25 @@
3
3
  # by: Noah Syrkis
4
4
 
5
5
  # imports
6
+ import os
7
+ from collections import namedtuple
8
+ from typing import Tuple
9
+ import cartopy.crs as ccrs
10
+
11
+ # types
12
+ BBox = namedtuple("BBox", ["north", "south", "east", "west"]) # type: ignore
13
+
14
+
15
+ # coordinate function
16
+ def to_mercator(bbox: BBox) -> BBox:
17
+ proj = ccrs.Mercator()
18
+ west, south = proj.transform_point(bbox.west, bbox.south, ccrs.PlateCarree())
19
+ east, north = proj.transform_point(bbox.east, bbox.north, ccrs.PlateCarree())
20
+ return BBox(north=north, south=south, east=east, west=west)
21
+
22
+
23
+ def to_platecarree(bbox: BBox) -> BBox:
24
+ proj = ccrs.PlateCarree()
25
+ west, south = proj.transform_point(bbox.west, bbox.south, ccrs.Mercator())
26
+ east, north = proj.transform_point(bbox.east, bbox.north, ccrs.Mercator())
27
+ return BBox(north=north, south=south, east=east, west=west)
parabellum/env.py CHANGED
@@ -2,15 +2,14 @@
2
2
 
3
3
  import jax.numpy as jnp
4
4
  import jax
5
- import numpy as np
6
- from jax import random, Array
7
- from jax import jit
5
+ from jax import random, Array, vmap, jit
8
6
  from flax.struct import dataclass
9
7
  import chex
10
- from jax import vmap
11
8
  from jaxmarl.environments.smax.smax_env import SMAX
9
+
12
10
  from typing import Tuple, Dict, cast
13
11
  from functools import partial
12
+ from parabellum import tps, geo
14
13
 
15
14
 
16
15
  @dataclass
@@ -18,8 +17,8 @@ class Scenario:
18
17
  """Parabellum scenario"""
19
18
 
20
19
  place: str
21
- terrain_raster: jnp.ndarray
22
- unit_starting_sectors: jnp.ndarray
20
+ terrain_raster: tps.Terrain
21
+ unit_starting_sectors: jnp.ndarray # must be of size (num_units, 4) where sectors[i] = (x, y, width, height) of the ith unit's spawning sector (in % of the real map)
23
22
  unit_types: chex.Array
24
23
  num_allies: int
25
24
  num_enemies: int
@@ -27,9 +26,11 @@ class Scenario:
27
26
  smacv2_position_generation: bool = False
28
27
  smacv2_unit_type_generation: bool = False
29
28
 
29
+
30
30
  @dataclass
31
31
  class State:
32
- unit_positions: Array
32
+ # terrain: Array
33
+ unit_positions: Array # fsfds
33
34
  unit_alive: Array
34
35
  unit_teams: Array
35
36
  unit_health: Array
@@ -41,71 +42,89 @@ class State:
41
42
  terminal: bool
42
43
 
43
44
 
44
- # default scenario
45
- scenarios = {
46
- "default": Scenario(
47
- "Identity Town",
48
- jnp.eye(64, dtype=jnp.uint8),
49
- jnp.array([[0, 0, 0.2, 0.2], [0.7,0.7,0.2,0.2]]),
50
- jnp.zeros((19,), dtype=jnp.uint8),
51
- 9,
52
- 10,
53
- )
54
- }
55
-
56
45
 
57
- def make_scenario(place, terrain_raster, unit_starting_sectors, allies_type, n_allies, enemies_type, n_enemies):
46
+ def make_scenario(
47
+ place,
48
+ size,
49
+ unit_starting_sectors,
50
+ allies_type,
51
+ n_allies,
52
+ enemies_type,
53
+ n_enemies,
54
+ ):
55
+ terrain = geo.geography_fn(place, size)
56
+ if type(unit_starting_sectors) == list:
57
+ default_sector = [0, 0, size, size] # Noah feel confident that this is right. This means 50% chance. Sorry timothee if you end up here later. my bad bro.
58
+ correct_unit_starting_sectors = []
59
+ for i in range(n_allies+n_enemies):
60
+ selected_sector = None
61
+ for unit_ids, sector in unit_starting_sectors:
62
+ if i in unit_ids:
63
+ selected_sector = sector
64
+ if selected_sector is None:
65
+ selected_sector = default_sector
66
+ correct_unit_starting_sectors.append(selected_sector)
67
+ unit_starting_sectors = correct_unit_starting_sectors
58
68
  if type(allies_type) == int:
59
69
  allies = [allies_type] * n_allies
60
70
  else:
61
- assert(len(allies_type) == n_allies)
71
+ assert len(allies_type) == n_allies
62
72
  allies = allies_type
63
73
 
64
74
  if type(enemies_type) == int:
65
75
  enemies = [enemies_type] * n_enemies
66
76
  else:
67
- assert(len(enemies_type) == n_enemies)
77
+ assert len(enemies_type) == n_enemies
68
78
  enemies = enemies_type
69
79
  unit_types = jnp.array(allies + enemies, dtype=jnp.uint8)
70
- return Scenario(place, terrain_raster, unit_starting_sectors, unit_types, n_allies, n_enemies)
80
+ return Scenario(
81
+ place, terrain, unit_starting_sectors, unit_types, n_allies, n_enemies # type: ignore
82
+ )
71
83
 
72
84
 
73
- def spawn_fn(pool, offset: jnp.ndarray, n: int, rng: jnp.ndarray):
85
+ def spawn_fn(rng: jnp.ndarray, units_spawning_sectors):
74
86
  """Spawns n agents on a map."""
75
- rng, key_start, key_noise = random.split(rng, 3)
76
- noise = random.uniform(key_noise, (n, 2)) * 0.5
77
-
78
- # select n random (x, y)-coords where sector == True
79
- idxs = random.choice(key_start, pool[0].shape[0], (n,), replace=False)
80
- coords = jnp.array([pool[0][idxs], pool[1][idxs]]).T
81
-
82
- return coords + noise + offset
83
-
84
-
85
- def sector_fn(terrain: jnp.ndarray, sector_id: int):
86
- """return sector slice of terrain"""
87
- width, height = terrain.shape
88
- coordx, coordy = sector_id // 5 * width // 5, sector_id % 5 * height // 5
89
- sector = terrain[coordx : coordx + width // 5, coordy : coordy + height // 5] == 0
90
- offset = jnp.array([coordx, coordy])
91
- # sector is jnp.nonzero
92
- return jnp.nonzero(sector), offset
93
-
94
-
95
- def sector_fn(terrain: jnp.ndarray, sector: jnp.ndarray):
96
- """return sector slice of terrain"""
97
- width, height = terrain.shape
98
- coordx, coordy = int(sector[0] * width), int(sector[1] * height)
99
- sector = terrain[coordy : coordy + int(sector[3] * height), coordx : coordx + int(sector[2] * width)] == 0
100
- offset = jnp.array([coordx, coordy])
101
- # sector is jnp.nonzero
102
- return jnp.nonzero(sector.T), offset
87
+ spawn_positions = []
88
+ for sector in units_spawning_sectors:
89
+ rng, key_start, key_noise = random.split(rng, 3)
90
+ noise = random.uniform(key_noise, (2,)) * 0.5
91
+ idx = random.choice(key_start, sector[0].shape[0])
92
+ coord = jnp.array([sector[0][idx], sector[1][idx]])
93
+ spawn_positions.append(coord + noise)
94
+ return jnp.array(spawn_positions, dtype=jnp.float32)
95
+
96
+
97
+ def sectors_fn(sectors: jnp.ndarray, invalid_spawn_areas: jnp.ndarray):
98
+ """
99
+ sectors must be of size (num_units, 4) where sectors[i] = (x, y, width, height) of the ith unit's spawning sector (in % of the real map)
100
+ """
101
+ width, height = invalid_spawn_areas.shape
102
+ spawning_sectors = []
103
+ for sector in sectors:
104
+ coordx, coordy = jnp.array(sector[0] * width, dtype=jnp.int32), jnp.array(sector[1] * height, dtype=jnp.int32)
105
+ sector = (invalid_spawn_areas[coordy : coordy + int(sector[3] * height), coordx : coordx + int(sector[2] * width)] == 0)
106
+ valid = jnp.nonzero(sector.T)
107
+ if valid[0].shape[0] == 0:
108
+ raise ValueError(f"Sector {sector} only contains invalid spawn areas.")
109
+ spawning_sectors.append(jnp.array(valid) + jnp.array([coordx, coordy]).reshape((2, -1) ))
110
+ return spawning_sectors
103
111
 
104
112
 
105
113
  class Environment(SMAX):
114
+
106
115
  def __init__(self, scenario: Scenario, **kwargs):
107
- map_height, map_width = scenario.terrain_raster.shape
116
+ map_height, map_width = scenario.terrain_raster.building.shape
108
117
  args = dict(scenario=scenario, map_height=map_height, map_width=map_width)
118
+ if "unit_type_pushable" in kwargs:
119
+ self.unit_type_pushable = kwargs["unit_type_pushable"]
120
+ del kwargs["unit_type_pushable"]
121
+ else:
122
+ self.unit_type_pushable = jnp.array([1,1,0,0,0,1])
123
+ if "reset_when_done" in kwargs:
124
+ self.reset_when_done = kwargs["reset_when_done"]
125
+ del kwargs["reset_when_done"]
126
+ else:
127
+ self.reset_when_done = True
109
128
  super(Environment, self).__init__(**args, walls_cause_death=False, **kwargs)
110
129
  self.terrain_raster = scenario.terrain_raster
111
130
  self.unit_starting_sectors = scenario.unit_starting_sectors
@@ -113,21 +132,18 @@ class Environment(SMAX):
113
132
  # self.unit_type_health = jnp.array([100, 100, 100, 100], dtype=jnp.float32)
114
133
  # self.unit_type_damage = jnp.array([10, 10, 10, 10], dtype=jnp.float32)
115
134
  self.scenario = scenario
116
- self.unit_type_velocities=jnp.array([3.15, 2.25, 4.13, 3.15, 4.13, 3.15])/2.5
135
+ self.unit_type_velocities = jnp.array([3.15, 2.25, 4.13, 3.15, 4.13, 3.15])/2.5 if "unit_type_velocities" not in kwargs else kwargs["unit_type_velocities"]
117
136
  self.unit_type_attack_blasts = jnp.zeros((3,), dtype=jnp.float32) # TODO: add
118
137
  self.max_steps = 200
119
- self._push_units_away = lambda state, firmness = 1: state # overwrite push units
120
- self.team0_sector, self.team0_sector_offset = sector_fn(self.terrain_raster, self.unit_starting_sectors[0]) # sector_fn(self.terrain_raster, 0)
121
- self.team1_sector, self.team1_sector_offset = sector_fn(self.terrain_raster, self.unit_starting_sectors[1]) # sector_fn(self.terrain_raster, 24)
138
+ self._push_units_away = lambda state, firmness=1: state # overwrite push units
139
+ self.spawning_sectors = sectors_fn(self.unit_starting_sectors, scenario.terrain_raster.building + scenario.terrain_raster.water)
140
+ self.resolution = self.terrain_raster.building.shape[0] + self.terrain_raster.building.shape[1]
141
+ self.t = jnp.tile(jnp.linspace(0, 1, self.resolution), (2, self.resolution))
122
142
 
123
143
 
124
- @partial(jax.jit, static_argnums=(0,))
125
- def reset(self, rng: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]:
144
+ def reset(self, rng: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]: # type: ignore
126
145
  """Environment-specific reset."""
127
- ally_key, enemy_key = jax.random.split(rng)
128
- team_0_start = spawn_fn(self.team0_sector, self.team0_sector_offset, self.num_allies, ally_key)
129
- team_1_start = spawn_fn(self.team1_sector, self.team1_sector_offset, self.num_enemies, enemy_key)
130
- unit_positions = jnp.concatenate([team_0_start, team_1_start])
146
+ unit_positions = spawn_fn(rng, self.spawning_sectors)
131
147
  unit_teams = jnp.zeros((self.num_agents,))
132
148
  unit_teams = unit_teams.at[self.num_allies :].set(1)
133
149
  unit_weapon_cooldowns = jnp.zeros((self.num_agents,))
@@ -145,19 +161,67 @@ class Environment(SMAX):
145
161
  time=0,
146
162
  terminal=False,
147
163
  unit_weapon_cooldowns=unit_weapon_cooldowns,
164
+ # terrain=self.terrain_raster,
148
165
  )
149
- state = self._push_units_away(state) # type: ignore
166
+ state = self._push_units_away(state) # type: ignore could be slow
150
167
  obs = self.get_obs(state)
151
168
  world_state = self.get_world_state(state)
152
169
  # obs["world_state"] = jax.lax.stop_gradient(world_state)
153
170
  return obs, state
154
171
 
155
- def step_env(self, rng, state: State, action: Array):
172
+ def step_env(self, rng, state: State, action: Array): # type: ignore
156
173
  obs, state, rewards, dones, infos = super().step_env(rng, state, action)
157
174
  # delete world_state from obs
158
175
  obs.pop("world_state")
176
+ if not self.reset_when_done:
177
+ for key in dones.keys():
178
+ dones[key] = False
159
179
  return obs, state, rewards, dones, infos
160
180
 
181
+ def get_obs_unit_list(self, state: State) -> Dict[str, chex.Array]: # type: ignore
182
+ """Applies observation function to state."""
183
+
184
+ def get_features(i, j):
185
+ """Get features of unit j as seen from unit i"""
186
+ # Can just keep them symmetrical for now.
187
+ # j here means 'the jth unit that is not i'
188
+ # The observation is such that allies are always first
189
+ # so for units in the second team we count in reverse.
190
+ j = jax.lax.cond(
191
+ i < self.num_allies,
192
+ lambda: j,
193
+ lambda: self.num_agents - j - 1,
194
+ )
195
+ offset = jax.lax.cond(i < self.num_allies, lambda: 1, lambda: -1)
196
+ j_idx = jax.lax.cond(
197
+ ((j < i) & (i < self.num_allies)) | ((j > i) & (i >= self.num_allies)),
198
+ lambda: j,
199
+ lambda: j + offset,
200
+ )
201
+ empty_features = jnp.zeros(shape=(len(self.unit_features),))
202
+ features = self._observe_features(state, i, j_idx)
203
+ visible = (
204
+ jnp.linalg.norm(state.unit_positions[j_idx] - state.unit_positions[i])
205
+ < self.unit_type_sight_ranges[state.unit_types[i]]
206
+ )
207
+ return jax.lax.cond(
208
+ visible & state.unit_alive[i] & state.unit_alive[j_idx]
209
+ & self.has_line_of_sight(state.unit_positions[j_idx], state.unit_positions[i], self.terrain_raster.building + self.terrain_raster.forest),
210
+ lambda: features,
211
+ lambda: empty_features,
212
+ )
213
+
214
+ get_all_features_for_unit = jax.vmap(get_features, in_axes=(None, 0))
215
+ get_all_features = jax.vmap(get_all_features_for_unit, in_axes=(0, None))
216
+ other_unit_obs = get_all_features(
217
+ jnp.arange(self.num_agents), jnp.arange(self.num_agents - 1)
218
+ )
219
+ other_unit_obs = other_unit_obs.reshape((self.num_agents, -1))
220
+ get_all_self_features = jax.vmap(self._get_own_features, in_axes=(None, 0))
221
+ own_unit_obs = get_all_self_features(state, jnp.arange(self.num_agents))
222
+ obs = jnp.concatenate([other_unit_obs, own_unit_obs], axis=-1)
223
+ return {agent: obs[self.agent_ids[agent]] for agent in self.agents}
224
+
161
225
  def _our_push_units_away(
162
226
  self, pos, unit_types, firmness: float = 1.0
163
227
  ): # copy of SMAX._push_units_away but used without state and called inside _world_step to allow more obstacles constraints
@@ -176,7 +240,19 @@ class Environment(SMAX):
176
240
  pos
177
241
  + firmness * jnp.sum(delta_matrix * overlap_term[:, :, None], axis=1) / 2
178
242
  )
179
- return unit_positions
243
+ return jnp.where(self.unit_type_pushable[unit_types][:, None], unit_positions, pos)
244
+
245
+ def has_line_of_sight(self, source, target, raster_input): # this is tooooo slow TODO: make it fast
246
+ # we could compute this for units in sight only using a switch
247
+
248
+ cells = jnp.array(source[:, jnp.newaxis] * self.t + (1-self.t) * target[:, jnp.newaxis], dtype=jnp.int32)
249
+
250
+ mask = jnp.zeros(raster_input.shape).at[cells[1, :], cells[0, :]].set(1)
251
+
252
+ flag = ~jnp.any(jnp.logical_and(mask, raster_input))
253
+
254
+ return flag
255
+
180
256
 
181
257
  @partial(jax.jit, static_argnums=(0,)) # replace the _world_step method
182
258
  def _world_step( # modified version of JaxMARL's SMAX _world_step
@@ -185,24 +261,16 @@ class Environment(SMAX):
185
261
  state: State,
186
262
  actions: Tuple[chex.Array, chex.Array],
187
263
  ) -> State:
188
- @partial(jax.vmap, in_axes=(None, None, 0, 0))
189
- def intersect_fn(pos, new_pos, obs, obs_end):
190
- d1 = jnp.cross(obs - pos, new_pos - pos)
191
- d2 = jnp.cross(obs_end - pos, new_pos - pos)
192
- d3 = jnp.cross(pos - obs, obs_end - obs)
193
- d4 = jnp.cross(new_pos - obs, obs_end - obs)
194
- return (d1 * d2 <= 0) & (d3 * d4 <= 0)
195
-
196
- def raster_crossing(pos, new_pos):
264
+ def raster_crossing(pos, new_pos, mask: jnp.ndarray):
197
265
  pos, new_pos = pos.astype(jnp.int32), new_pos.astype(jnp.int32)
198
- raster = jnp.copy(self.terrain_raster)
199
266
  minimum = jnp.minimum(pos, new_pos)
200
267
  maximum = jnp.maximum(pos, new_pos)
201
- raster = jnp.where(jnp.arange(raster.shape[0]) >= minimum[0], raster, 0)
202
- raster = jnp.where(jnp.arange(raster.shape[0]) <= maximum[0], raster, 0)
203
- raster = jnp.where(jnp.arange(raster.shape[1]) >= minimum[1], raster.T, 0).T
204
- raster = jnp.where(jnp.arange(raster.shape[1]) <= maximum[1], raster.T, 0).T
205
- return jnp.any(raster)
268
+ mask = jnp.where(jnp.arange(mask.shape[0]) >= minimum[0], mask, 0)
269
+ mask = jnp.where(jnp.arange(mask.shape[0]) <= maximum[0], mask, 0)
270
+ mask = jnp.where(jnp.arange(mask.shape[1]) >= minimum[1], mask.T, 0).T
271
+ mask = jnp.where(jnp.arange(mask.shape[1]) <= maximum[1], mask.T, 0).T
272
+ return jnp.any(mask)
273
+
206
274
 
207
275
  def update_position(idx, vec):
208
276
  # Compute the movements slightly strangely.
@@ -219,13 +287,13 @@ class Environment(SMAX):
219
287
  )
220
288
  # avoid going out of bounds
221
289
  new_pos = jnp.maximum(
222
- jnp.minimum(new_pos, jnp.array([self.map_width, self.map_height])),
290
+ jnp.minimum(new_pos, jnp.array([self.map_width-1, self.map_height-1])),
223
291
  jnp.zeros((2,)),
224
292
  )
225
293
 
226
294
  #######################################################################
227
295
  ############################################ avoid going into obstacles
228
- clash = raster_crossing(pos, new_pos)
296
+ clash = raster_crossing(pos, new_pos, self.terrain_raster.building + self.terrain_raster.water)
229
297
  new_pos = jnp.where(clash, pos, new_pos)
230
298
 
231
299
  #######################################################################
@@ -290,7 +358,7 @@ class Environment(SMAX):
290
358
  bystander_idxs = bystander_fn(attacked_idx) # TODO: use
291
359
  bystander_valid = (
292
360
  jnp.where(attack_valid, bystander_idxs, jnp.zeros((self.num_agents,)))
293
- .astype(jnp.bool_) # type: ignore
361
+ .astype(jnp.bool_) # type: ignore
294
362
  .astype(jnp.float32)
295
363
  )
296
364
  bystander_health_diff = (
@@ -338,10 +406,13 @@ class Environment(SMAX):
338
406
 
339
407
  # units push each other
340
408
  new_pos = self._our_push_units_away(pos, state.unit_types)
341
- clash = jax.vmap(raster_crossing)(pos, new_pos)
409
+ clash = jax.vmap(raster_crossing, in_axes=(0, 0, None))(pos, new_pos, self.terrain_raster.building + self.terrain_raster.water)
342
410
  pos = jax.vmap(jnp.where)(clash, pos, new_pos)
343
411
  # avoid going out of bounds
344
- pos = jnp.maximum(jnp.minimum(pos, jnp.array([self.map_width, self.map_height])),jnp.zeros((2,)),)
412
+ pos = jnp.maximum(
413
+ jnp.minimum(pos, jnp.array([self.map_width - 1, self.map_height - 1])), # type: ignore
414
+ jnp.zeros((2,)),
415
+ )
345
416
 
346
417
  # Multiple enemies can attack the same unit.
347
418
  # We have `(health_diff, attacked_idx)` pairs.
@@ -385,26 +456,39 @@ class Environment(SMAX):
385
456
  )
386
457
  return state
387
458
 
388
-
389
459
  if __name__ == "__main__":
390
460
  n_envs = 4
391
461
 
392
- env = Environment(scenarios["default"])
462
+
463
+ n_allies = 10
464
+ scenario_kwargs = {"allies_type": 0, "n_allies": n_allies, "enemies_type": 0, "n_enemies": n_allies,
465
+ "place": "Vesterbro, Copenhagen, Denmark", "size": 256, "unit_starting_sectors":
466
+ [([i for i in range(n_allies)], [0.,0.45,0.1,0.1]), ([n_allies+i for i in range(n_allies)], [0.8,0.5,0.1,0.1])]}
467
+ scenario = make_scenario(**scenario_kwargs)
468
+ env = Environment(scenario)
393
469
  rng, reset_rng = random.split(random.PRNGKey(0))
394
470
  reset_key = random.split(reset_rng, n_envs)
395
471
  obs, state = vmap(env.reset)(reset_key)
396
472
  state_seq = []
397
473
 
398
- print(state.unit_positions)
399
- exit()
400
474
 
401
- for i in range(10):
475
+ from tqdm import tqdm
476
+ import time
477
+ step = vmap(jit(env.step))
478
+ tic = time.time()
479
+ for i in tqdm(range(10)):
402
480
  rng, act_rng, step_rng = random.split(rng, 3)
403
481
  act_key = random.split(act_rng, (len(env.agents), n_envs))
482
+ print(tic - time.time())
404
483
  act = {
405
484
  a: vmap(env.action_space(a).sample)(act_key[i])
406
485
  for i, a in enumerate(env.agents)
407
486
  }
487
+ print(tic - time.time())
408
488
  step_key = random.split(step_rng, n_envs)
489
+ print(tic - time.time())
409
490
  state_seq.append((step_key, state, act))
410
- obs, state, reward, done, infos = vmap(env.step)(step_key, state, act)
491
+ print(tic - time.time())
492
+ obs, state, reward, done, infos = step(step_key, state, act)
493
+ print(tic - time.time())
494
+ tic = time.time()
parabellum/geo.py ADDED
@@ -0,0 +1,100 @@
1
+ # %% geo.py
2
+ # script for geospatial level generation
3
+ # by: Noah Syrkis
4
+
5
+ # %% Imports
6
+ from parabellum import tps
7
+ import rasterio
8
+ from rasterio import features, transform
9
+ from geopy.geocoders import Nominatim
10
+ from geopy.distance import distance
11
+ import contextily as cx
12
+ import jax.numpy as jnp
13
+ import cartopy.crs as ccrs
14
+ from jaxtyping import Array
15
+ import numpy as np
16
+ from shapely import box
17
+ import osmnx as ox
18
+ import geopandas as gpd
19
+ from collections import namedtuple
20
+ from typing import Tuple
21
+ import matplotlib.pyplot as plt
22
+ import seaborn as sns
23
+ import os
24
+
25
+ # %% Types
26
+ Coords = Tuple[float, float]
27
+ BBox = namedtuple("BBox", ["north", "south", "east", "west"]) # type: ignore
28
+
29
+ # %% Constants
30
+ provider = cx.providers.Stadia.StamenTerrain( # type: ignore
31
+ api_key="86d0d32b-d2fe-49af-8db8-f7751f58e83f"
32
+ )
33
+ provider["url"] = provider["url"] + "?api_key={api_key}"
34
+ tags = {"building": True, "water": True, "landuse": "forest"} # "road": True}
35
+
36
+
37
+ # %% Coordinate function
38
+ def get_coordinates(place: str) -> Coords:
39
+ geolocator = Nominatim(user_agent="parabellum")
40
+ point = geolocator.geocode(place)
41
+ return point.latitude, point.longitude # type: ignore
42
+
43
+
44
+ def get_bbox(place: str, buffer) -> BBox:
45
+ """Get bounding box from place name in crs 4326."""
46
+ coords = get_coordinates(place)
47
+ north = distance(meters=buffer).destination(coords, bearing=0).latitude
48
+ south = distance(meters=buffer).destination(coords, bearing=180).latitude
49
+ east = distance(meters=buffer).destination(coords, bearing=90).longitude
50
+ west = distance(meters=buffer).destination(coords, bearing=270).longitude
51
+ return BBox(north, south, east, west)
52
+
53
+
54
+ def basemap_fn(bbox: BBox, gdf) -> Array:
55
+ fig, ax = plt.subplots(figsize=(20, 20), subplot_kw={"projection": ccrs.Mercator()})
56
+ gdf.plot(ax=ax, color="black", alpha=0, edgecolor="black") # type: ignore
57
+ cx.add_basemap(ax, crs=gdf.crs, source=provider, zoom="auto") # type: ignore
58
+ bbox = gdf.total_bounds
59
+ ax.set_extent([bbox[0], bbox[2], bbox[1], bbox[3]], crs=ccrs.Mercator()) # type: ignore
60
+ plt.axis("off")
61
+ plt.tight_layout()
62
+ fig.canvas.draw()
63
+ image = jnp.array(fig.canvas.renderer._renderer) # type: ignore
64
+ plt.close(fig)
65
+ return image
66
+
67
+
68
+ def geography_fn(place, buffer):
69
+ bbox = get_bbox(place, buffer)
70
+ map_data = ox.features_from_bbox(bbox=bbox, tags=tags)
71
+ gdf = gpd.GeoDataFrame(map_data)
72
+ gdf = gdf.clip(box(bbox.west, bbox.south, bbox.east, bbox.north)).to_crs("EPSG:3857")
73
+ raster = raster_fn(gdf, shape=(buffer, buffer))
74
+ basemap = basemap_fn(bbox, gdf)
75
+ terrain = tps.Terrain(building=raster[0], water=raster[1], forest=raster[2], basemap=basemap)
76
+ return terrain
77
+
78
+
79
+ def raster_fn(gdf, shape) -> Array:
80
+ bbox = gdf.total_bounds
81
+ t = transform.from_bounds(*bbox, *shape) # type: ignore
82
+ raster = jnp.array([feature_fn(t, feature, gdf, shape) for feature in ["building", "water", "landuse"]])
83
+ return raster
84
+
85
+ def feature_fn(t, feature, gdf, shape):
86
+ if feature not in gdf.columns:
87
+ return jnp.zeros(shape)
88
+ gdf = gdf[~gdf[feature].isna()]
89
+ raster = features.rasterize(gdf.geometry, out_shape=shape, transform=t, fill=0) # type: ignore
90
+ return raster
91
+
92
+ place = "Thun, Switzerland"
93
+ terrain = geography_fn(place, 800)
94
+ # %%
95
+ fig, axes = plt.subplots(1, 5, figsize=(20, 20))
96
+ axes[0].imshow(terrain.building, cmap="gray")
97
+ axes[1].imshow(terrain.water, cmap="gray")
98
+ axes[2].imshow(terrain.forest, cmap="gray")
99
+ axes[3].imshow(terrain.building + terrain.water + terrain.forest)
100
+ axes[4].imshow(terrain.basemap)
parabellum/map.py CHANGED
@@ -1,100 +1,95 @@
1
- # map.py
2
- # parabellum map functions
1
+ # ludens.py
2
+ # script for fucking around and finding out
3
3
  # by: Noah Syrkis
4
4
 
5
- # imports
6
- import jax.numpy as jnp
7
- from geopy.geocoders import Nominatim
8
- import geopandas as gpd
5
+
6
+ # %% Imports
7
+ # import parabellum as pb
8
+ import matplotlib.pyplot as plt
9
9
  import osmnx as ox
10
+ from geopy.geocoders import Nominatim
11
+ import numpy as np
10
12
  import contextily as cx
11
- import matplotlib.pyplot as plt
13
+ import jax.numpy as jnp
14
+ import geopandas as gpd
15
+ import rasterio
12
16
  from rasterio import features
13
- import rasterio.transform
14
- from typing import Optional, Tuple
15
- from geopy.location import Location
16
17
  from shapely.geometry import Point
17
- import os
18
- import pickle
18
+ from typing import List
19
19
 
20
- # constants
20
+ # %% Constants
21
21
  geolocator = Nominatim(user_agent="parabellum")
22
- BUILDING_TAGS = {"building": True}
23
-
24
- def get_location(place: str) -> Tuple[float, float]:
25
- """Get coordinates for a given place."""
26
- coords: Optional[Location] = geolocator.geocode(place) # type: ignore
27
- if coords is None:
28
- raise ValueError(f"Could not geocode the place: {place}")
29
- return (coords.latitude, coords.longitude)
30
-
31
- def get_building_geometry(point: Tuple[float, float], size: int) -> gpd.GeoDataFrame:
32
- """Get building geometry for a given point and size."""
33
- geometry = ox.features_from_point(point, tags=BUILDING_TAGS, dist=size // 2)
34
- return gpd.GeoDataFrame(geometry).set_crs("EPSG:4326")
35
-
36
- def rasterize_geometry(gdf: gpd.GeoDataFrame, size: int) -> jnp.ndarray:
37
- """Rasterize geometry and return as a JAX array."""
38
- w, s, e, n = gdf.total_bounds
39
- transform = rasterio.transform.from_bounds(w, s, e, n, size, size)
40
- raster = features.rasterize(gdf.geometry, out_shape=(size, size), transform=transform)
41
- return jnp.array(jnp.flip(raster, 0) ).astype(jnp.uint8)
42
-
43
- # +
44
- def get_from_cache(place, size):
45
- if os.path.exists("./cache"):
46
- name = str(hash((place, size))) + ".pk"
47
- if os.path.exists("./cache/" + name):
48
- with open("./cache/" + name, "rb") as f:
49
- (mask, base) = pickle.load(f)
50
- return (mask, base.astype(jnp.int64))
51
- return (None, None)
52
-
53
- def save_in_cache(place, size, mask, base):
54
- if not os.path.exists("./cache"):
55
- os.makedirs("./cache")
56
- name = str(hash((place, size))) + ".pk"
57
- with open("./cache/" + name, "wb") as f:
58
- pickle.dump((mask, base), f)
59
-
60
- def terrain_fn(place: str, size: int = 1000, with_cache: bool = True) -> Tuple[jnp.ndarray, jnp.ndarray]:
61
- """Returns a rasterized map of buildings for a given location."""
62
- if with_cache:
63
- mask, base = get_from_cache(place, size)
64
- if not with_cache or mask is None:
65
- point = get_location(place)
66
- gdf = get_building_geometry(point, size)
67
- mask = rasterize_geometry(gdf, size)
68
- base = get_basemap(place, size)
69
- if with_cache:
70
- save_in_cache(place, size, mask, base)
71
- return mask, base
72
-
73
-
74
- # -
75
-
76
- def get_basemap(place: str, size: int = 1000) -> jnp.ndarray:
77
- """Returns a basemap for a given place as a JAX array."""
78
- point = get_location(place)
79
- gdf = get_building_geometry(point, size)
80
- basemap, _ = cx.bounds2img(*gdf.total_bounds, ll=True)
81
- # get the middle size x size square
82
- basemap = basemap[(basemap.shape[0] - size) // 2:(basemap.shape[0] + size) // 2,
83
- (basemap.shape[1] - size) // 2:(basemap.shape[1] + size) // 2]
84
- return basemap # jnp.array(jnp.rot90(basemap, 2)).astype(jnp.uint8)
85
-
86
-
87
- if __name__ == "__main__":
88
- place = "Cauvicourt, 14190, France"
89
- mask, base = terrain_fn(place, 500)
90
-
91
- fig, ax = plt.subplots(1, 3, figsize=(15, 5))
92
- ax[0].imshow(jnp.flip(mask,0)) # type: ignore
93
- ax[1].imshow(base) # type: ignore
94
- ax[2].imshow(base) # type: ignore
95
- ax[2].imshow(jnp.flip(mask,0), alpha=jnp.flip(mask,0)) # type: ignore
96
- plt.show()
97
-
98
-
99
-
100
-
22
+ source = cx.providers.OpenStreetMap.Mapnik # type: ignore
23
+
24
+
25
+ def get_raster(
26
+ place: str, meters: int = 1000, tags: List[dict] | dict = {"building": True}
27
+ ) -> jnp.ndarray:
28
+ # look here for tags https://wiki.openstreetmap.org/wiki/Map_features
29
+ def aux(place, tag):
30
+ """Rasterize geometry and return as a JAX array."""
31
+ place = geolocator.geocode(place) # type: ignore
32
+ point = place.latitude, place.longitude # type: ignore # confusing order of lat/lon
33
+ geom = ox.features_from_point(point, tags=tag, dist=meters // 2)
34
+ gdf = gpd.GeoDataFrame(geom).set_crs("EPSG:4326")
35
+ # crop everythin outside of the meters x meters square
36
+ gdf = gdf.cx[
37
+ place.longitude - meters / 2 : place.longitude + meters / 2,
38
+ place.latitude - meters / 2 : place.latitude + meters / 2,
39
+ ]
40
+
41
+ # bounds should be meters, meters
42
+ t = rasterio.transform.from_bounds(*bounds, meters, meters) # type: ignore
43
+ raster = features.rasterize(
44
+ gdf.geometry, out_shape=(meters, meters), transform=t
45
+ )
46
+ return jnp.array(raster)
47
+
48
+ if isinstance(tags, dict):
49
+ return aux(place, tags)
50
+ else:
51
+ return jnp.stack([aux(place, tag) for tag in tags])
52
+
53
+
54
+ def get_basemap(
55
+ place: str, size: int = 1000
56
+ ) -> np.ndarray: # TODO: image is slightly off from raster. Fix this.
57
+ # Create a GeoDataFrame with the center point
58
+ place = geolocator.geocode(place) # type: ignore
59
+ lon, lat = place.longitude, place.latitude # type: ignore
60
+ gdf = gpd.GeoDataFrame(geometry=[Point(lon, lat)], crs="EPSG:4326")
61
+ gdf = gdf.to_crs("EPSG:3857")
62
+
63
+ # Create a buffer around the center point
64
+ # buffer = gdf.buffer(size) # type: ignore
65
+ buffer = gdf
66
+ bounds = buffer.total_bounds # i think this is wrong, since it ignores empty space
67
+ # modify bounds to include empty space
68
+ bounds = (bounds[0] - size, bounds[1] - size, bounds[2] + size, bounds[3] + size)
69
+
70
+ # Create a figure and axis
71
+ dpi = 300
72
+ fig, ax = plt.subplots(figsize=(size / dpi, size / dpi), dpi=dpi)
73
+ buffer.plot(ax=ax, facecolor="none", edgecolor="red", linewidth=0)
74
+
75
+ # Calculate the zoom level for the basemap
76
+
77
+ # Add the basemap to the axis
78
+ cx.add_basemap(ax, source=source, zoom="auto", attribution=False)
79
+
80
+ # Set the x and y limits of the axis
81
+ ax.set_xlim(bounds[0], bounds[2])
82
+ ax.set_ylim(bounds[1], bounds[3])
83
+
84
+ # convert the image (without axis or border) to a numpy array
85
+ plt.axis("off")
86
+ plt.tight_layout()
87
+
88
+ # remove whitespace
89
+ fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
90
+ fig.canvas.draw()
91
+
92
+ image = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) # type: ignore
93
+ image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,))
94
+ plt.close()
95
+ return jnp.array(image) # type: ignore
parabellum/pcg.py ADDED
@@ -0,0 +1,53 @@
1
+ # pcg.py
2
+ # procedural content generation
3
+ # by: Noah Syrkis
4
+
5
+ # %% Imports
6
+ from jax import random, vmap
7
+ import jax.numpy as jnp
8
+ import matplotlib.pyplot as plt
9
+ from functools import partial
10
+
11
+
12
+ # %% Functions
13
+ seed = 0
14
+ n = 100
15
+ rng = random.PRNGKey(seed)
16
+ Y = random.uniform(rng, (n,))
17
+
18
+
19
+ def g(t):
20
+ return (1 - jnp.cos(jnp.pi * t)) / 2
21
+
22
+
23
+ def lerp(a, b, t):
24
+ t -= jnp.floor(t) # the fractional part of t
25
+ return (1 - t) * a + t * b
26
+
27
+
28
+ def cerp(a, b, t):
29
+ t -= jnp.floor(t) # the fractional part of t
30
+ return g(1 - t) * a + g(t) * b
31
+
32
+
33
+ def body_fn(x):
34
+ i = jnp.floor(x).astype(jnp.uint8)
35
+ return cerp(Y[i], Y[i + 1], x)
36
+
37
+
38
+ @partial(vmap, in_axes=(None, 0, None))
39
+ def noise_fn(y, t, n):
40
+ return y[t % n]
41
+
42
+
43
+ @vmap
44
+ def perlin_fn(t):
45
+ return noise_fn(Y, t * jnp.arange(n * 3), n)
46
+
47
+
48
+ xs = jnp.linspace(0, 1, 1000)
49
+ noise = perlin_fn(2 ** jnp.arange(3)).sum(0)
50
+
51
+ fig, ax = plt.subplots(figsize=(20, 4), dpi=100)
52
+ ax.set_ylim(0, 1)
53
+ ax.plot(noise / noise.max())
parabellum/run.py CHANGED
@@ -21,7 +21,7 @@ bg = (0, 0, 0) if darkdetect.isDark() else (255, 255, 255)
21
21
 
22
22
 
23
23
  # types
24
- State = jaxmarl.environments.smax.smax_env.State # type: ignore
24
+ State = jaxmarl.environments.smax.smax_env.State # type: ignore
25
25
  Obs = Reward = Done = Action = Dict[str, jnp.ndarray]
26
26
  StateSeq = List[Tuple[jnp.ndarray, State, Action]]
27
27
 
parabellum/tps.py ADDED
@@ -0,0 +1,16 @@
1
+ # tps.py
2
+ # parabellum types and dataclasses
3
+ # by: Noah Syrkis
4
+
5
+ # %% Imports
6
+ from chex import dataclass
7
+ from jaxtyping import Array
8
+
9
+
10
+ # %% Dataclasses
11
+ @dataclass
12
+ class Terrain:
13
+ building: Array
14
+ water: Array
15
+ forest: Array
16
+ basemap: Array
parabellum/vis.py CHANGED
@@ -3,26 +3,19 @@ Visualizer for the Parabellum environment
3
3
  """
4
4
 
5
5
  # Standard library imports
6
- from functools import partial
7
- from typing import Optional, List, Tuple
6
+ from typing import Optional, Tuple
8
7
  import cv2
9
- from PIL import Image
10
8
 
11
9
  # JAX and JAX-related imports
12
10
  import jax
13
11
  from chex import dataclass
14
- import chex
15
- from jax import vmap, tree_util, Array, jit
12
+ from jax import vmap, Array
16
13
  import jax.numpy as jnp
17
- from jaxmarl.environments.multi_agent_env import MultiAgentEnv
18
- from jaxmarl.environments.smax import SMAX
19
14
  from jaxmarl.viz.visualizer import SMAXVisualizer
20
15
 
21
16
  # Third-party imports
22
17
  import numpy as np
23
18
  import pygame
24
- import cv2
25
- from tqdm import tqdm
26
19
 
27
20
  # Local imports
28
21
  import parabellum as pb
@@ -35,12 +28,12 @@ class Skin:
35
28
  maskmap: Array # maskmap of buildings
36
29
  bg: Tuple[int, int, int] = (255, 255, 255)
37
30
  fg: Tuple[int, int, int] = (0, 0, 0)
38
- ally: Tuple[int, int, int] = (0, 255, 0)
39
- enemy: Tuple[int, int, int] = (255, 0, 0)
40
- pad: int = 100
41
- size: int = 1000 # excluding padding
31
+ ally: Tuple[int, int, int] = (0, 255, 0)
32
+ enemy: Tuple[int, int, int] = (255, 0, 0)
33
+ pad: int = 100
34
+ size: int = 1000 # excluding padding
42
35
  fps: int = 24
43
- vis_size: int = 1000 # size of the map in Vis (exluding padding)
36
+ vis_size: int = 1000 # size of the map in Vis (exluding padding)
44
37
  scale: Optional[float] = None
45
38
 
46
39
 
@@ -57,10 +50,23 @@ class Visualizer(SMAXVisualizer):
57
50
  self.env = env
58
51
 
59
52
  def animate(self, save_fname: Optional[str] = "output/parabellum", view=None):
60
- expanded_state_seq, expanded_action_seq = expand_fn(self.env, self.state_seq, self.action_seq)
61
- state_seq_seq, action_seq_seq = unbatch_fn(expanded_state_seq, expanded_action_seq)
62
- for idx, (state_seq, action_seq) in enumerate(zip(state_seq_seq, action_seq_seq)):
63
- animate_fn(self.env, self.skin, self.image, state_seq, action_seq, f"{save_fname}_{idx}.mp4")
53
+ expanded_state_seq, expanded_action_seq = expand_fn(
54
+ self.env, self.state_seq, self.action_seq
55
+ )
56
+ state_seq_seq, action_seq_seq = unbatch_fn(
57
+ expanded_state_seq, expanded_action_seq
58
+ )
59
+ for idx, (state_seq, action_seq) in enumerate(
60
+ zip(state_seq_seq, action_seq_seq)
61
+ ):
62
+ animate_fn(
63
+ self.env,
64
+ self.skin,
65
+ self.image,
66
+ state_seq,
67
+ action_seq,
68
+ f"{save_fname}_{idx}.mp4",
69
+ )
64
70
 
65
71
 
66
72
  # functions
@@ -70,22 +76,29 @@ def animate_fn(env, skin, image, state_seq, action_seq, save_fname):
70
76
  for idx, (state_tup, action) in enumerate(zip(state_seq, action_seq)):
71
77
  frames += [frame_fn(env, skin, image, state_tup[1], action, idx)]
72
78
  # use cv2 to write frames to video
73
- fourcc = cv2.VideoWriter_fourcc(*'mp4v') # type: ignore
74
- out = cv2.VideoWriter(save_fname, fourcc, skin.fps, (skin.size + skin.pad * 2, skin.size + skin.pad * 2))
79
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v") # type: ignore
80
+ out = cv2.VideoWriter(
81
+ save_fname,
82
+ fourcc,
83
+ skin.fps,
84
+ (skin.size + skin.pad * 2, skin.size + skin.pad * 2),
85
+ )
75
86
  for frame in frames:
76
87
  out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
77
88
  out.release()
78
89
  pygame.quit()
79
90
 
80
91
 
81
- def init_frame(env, skin, image, state: pb.State, action: Array, idx: int) -> pygame.Surface:
92
+ def init_frame(
93
+ env, skin, image, state: pb.State, action: Array, idx: int
94
+ ) -> pygame.Surface:
82
95
  dims = (skin.size + skin.pad * 2, skin.size + skin.pad * 2)
83
96
  frame = pygame.Surface(dims, pygame.SRCALPHA | pygame.HWSURFACE)
84
97
  return frame
85
98
 
86
99
 
87
100
  def transform_frame(env, skin, frame):
88
- #frame = np.rot90(pygame.surfarray.pixels3d(frame).swapaxes(0, 1), 2)
101
+ # frame = np.rot90(pygame.surfarray.pixels3d(frame).swapaxes(0, 1), 2)
89
102
  frame = np.flip(pygame.surfarray.pixels3d(frame).swapaxes(0, 1), 0)
90
103
  return frame
91
104
 
@@ -102,7 +115,7 @@ def frame_fn(env, skin, image, state: pb.State, action: Array, idx: int) -> np.n
102
115
 
103
116
 
104
117
  def render_background(env, skin, image, frame, state, action):
105
- coords = (skin.pad-5, skin.pad-5, skin.size+10, skin.size+10)
118
+ coords = (skin.pad - 5, skin.pad - 5, skin.size + 10, skin.size + 10)
106
119
  frame.fill(skin.bg)
107
120
  frame.blit(image, coords)
108
121
  pygame.draw.rect(frame, skin.fg, coords, 3)
@@ -116,6 +129,7 @@ def render_action(env, skin, image, frame, state, action):
116
129
  def render_bullet(env, skin, image, frame, state, action):
117
130
  return frame
118
131
 
132
+
119
133
  def render_agents(env, skin, image, frame, state, action):
120
134
  units = state.unit_positions, state.unit_teams, state.unit_types, state.unit_health
121
135
  for idx, (pos, team, kind, health) in enumerate(zip(*units)):
@@ -136,7 +150,11 @@ def text_fn(text):
136
150
 
137
151
  def image_fn(skin: Skin): # TODO:
138
152
  """Create an image for background (basemap or maskmap)"""
139
- motif = cv2.resize(np.array(skin.maskmap.T), (skin.size, skin.size), interpolation=cv2.INTER_NEAREST).astype(np.uint8)
153
+ motif = cv2.resize(
154
+ np.array(skin.maskmap.T),
155
+ (skin.size, skin.size),
156
+ interpolation=cv2.INTER_NEAREST,
157
+ ).astype(np.uint8)
140
158
  motif = (motif > 0).astype(np.uint8)
141
159
  image = np.zeros((skin.size, skin.size, 3), dtype=np.uint8) + skin.bg
142
160
  image[motif == 1] = skin.fg
@@ -150,7 +168,9 @@ def unbatch_fn(state_seq, action_seq):
150
168
  if is_multi_run(state_seq):
151
169
  n_envs = state_seq[0][1].unit_positions.shape[0]
152
170
  state_seq_seq = [jax.tree_map(lambda x: x[i], state_seq) for i in range(n_envs)]
153
- action_seq_seq = [jax.tree_map(lambda x: x[i], action_seq) for i in range(n_envs)]
171
+ action_seq_seq = [
172
+ jax.tree_map(lambda x: x[i], action_seq) for i in range(n_envs)
173
+ ]
154
174
  else:
155
175
  state_seq_seq = [state_seq]
156
176
  action_seq_seq = [action_seq]
@@ -161,7 +181,9 @@ def expand_fn(env, state_seq, action_seq):
161
181
  """Expand the state sequence"""
162
182
  fn = env.expand_state_seq
163
183
  state_seq = vmap(fn)(state_seq) if is_multi_run(state_seq) else fn(state_seq)
164
- action_seq = [action_seq[i // env.world_steps_per_env_step] for i in range(len(state_seq))]
184
+ action_seq = [
185
+ action_seq[i // env.world_steps_per_env_step] for i in range(len(state_seq))
186
+ ]
165
187
  return state_seq, action_seq
166
188
 
167
189
 
@@ -1,35 +1,37 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: parabellum
3
- Version: 0.2.26
3
+ Version: 0.3.0
4
4
  Summary: Parabellum environment for parallel warfare simulation
5
5
  Home-page: https://github.com/syrkis/parabellum
6
6
  License: MIT
7
7
  Keywords: warfare,simulation,parallel,environment
8
8
  Author: Noah Syrkis
9
9
  Author-email: desk@syrkis.com
10
- Requires-Python: >=3.11,<4.0
10
+ Requires-Python: >=3.11,<3.12
11
11
  Classifier: License :: OSI Approved :: MIT License
12
12
  Classifier: Programming Language :: Python :: 3
13
13
  Classifier: Programming Language :: Python :: 3.11
14
- Classifier: Programming Language :: Python :: 3.12
14
+ Requires-Dist: cartopy (>=0.23.0,<0.24.0)
15
15
  Requires-Dist: contextily (>=1.6.0,<2.0.0)
16
16
  Requires-Dist: darkdetect (>=0.8.0,<0.9.0)
17
+ Requires-Dist: einops (>=0.8.0,<0.9.0)
17
18
  Requires-Dist: folium (>=0.17.0,<0.18.0)
18
- Requires-Dist: geopandas (>=1.0.0,<2.0.0)
19
19
  Requires-Dist: geopy (>=2.4.1,<3.0.0)
20
20
  Requires-Dist: ipykernel (>=6.29.5,<7.0.0)
21
21
  Requires-Dist: jax (==0.4.17)
22
22
  Requires-Dist: jaxmarl (==0.0.3)
23
+ Requires-Dist: jaxtyping (>=0.2.33,<0.3.0)
23
24
  Requires-Dist: jupyterlab (>=4.2.2,<5.0.0)
24
25
  Requires-Dist: moviepy (>=1.0.3,<2.0.0)
25
26
  Requires-Dist: numpy (<2)
26
27
  Requires-Dist: opencv-python (>=4.10.0.84,<5.0.0.0)
27
- Requires-Dist: osmnx (>=1.9.3,<2.0.0)
28
+ Requires-Dist: osmnx (==2.0.0b0)
28
29
  Requires-Dist: pandas (>=2.2.2,<3.0.0)
29
30
  Requires-Dist: poetry (>=1.8.3,<2.0.0)
30
31
  Requires-Dist: pygame (>=2.5.2,<3.0.0)
31
32
  Requires-Dist: rasterio (>=1.3.10,<2.0.0)
32
33
  Requires-Dist: seaborn (>=0.13.2,<0.14.0)
34
+ Requires-Dist: stadiamaps (>=3.2.1,<4.0.0)
33
35
  Requires-Dist: tqdm (>=4.66.4,<5.0.0)
34
36
  Project-URL: Repository, https://github.com/syrkis/parabellum
35
37
  Description-Content-Type: text/markdown
@@ -0,0 +1,13 @@
1
+ parabellum/__init__.py,sha256=vqQbvsTT_zcLThZ7fLoJ6cMAZbEeGIJDFyCkHmovfOY,392
2
+ parabellum/aid.py,sha256=BPabjN4BUq1HRhkwbc9pCNsXSF_ALiG8W8cHWTWeEH4,900
3
+ parabellum/env.py,sha256=VV3VK7TTkianihqJopRbY0vlRWOquu-VTrc9ep0PSTk,21304
4
+ parabellum/geo.py,sha256=xkj6iJqN076tRbaG38Sq7gtwKSNzxI37msRLnpn5JV0,3561
5
+ parabellum/gun.py,sha256=nvsJdcZ2Qd6lbPlAgsUiaLhstTi1UdLQ8kOnbCenucY,2618
6
+ parabellum/map.py,sha256=9AV0PIqInXcWWojzHshy3X42Nm3ZDq0O1NG-6fQ9Wgw,3345
7
+ parabellum/pcg.py,sha256=d8KC_lbc4WUUUPaTdPJSx27VMGioys3jSGOWJ-2EahU,968
8
+ parabellum/run.py,sha256=Q53__AxzROZNgfZLVU5LDdcT61UMCkmQ_Q5wWUIrnqo,3473
9
+ parabellum/tps.py,sha256=3tVqo42ggE8idZn500C0X2pS9TmYndgBzlAG7Yj2Wz8,252
10
+ parabellum/vis.py,sha256=ABHveJj0fLRWkxOv3LFIXK20QtdGhjskuFLsp7iTFu0,6185
11
+ parabellum-0.3.0.dist-info/METADATA,sha256=FugXwz25bAPYKlIfqFc7dGVtPupse5zHYapmqBWopE8,2740
12
+ parabellum-0.3.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
13
+ parabellum-0.3.0.dist-info/RECORD,,
@@ -1,10 +0,0 @@
1
- parabellum/__init__.py,sha256=-5cWXJkHnfH_CbhTEall8Wak8McAFXZHP1L8Fu7Uo5k,373
2
- parabellum/aid.py,sha256=HWST27inTFXcp8b11izJF0U7N7DZnRTIS3n1Qfa-Ko4,106
3
- parabellum/env.py,sha256=H1YGAtUYNJd8OHnZ3sOEXbag5L0WjtJHBGL8ymGPvoE,16898
4
- parabellum/gun.py,sha256=nvsJdcZ2Qd6lbPlAgsUiaLhstTi1UdLQ8kOnbCenucY,2618
5
- parabellum/map.py,sha256=UwMqwySasX5oLw9v5YMJARAPwvQThLTRW36NpbwvBC8,3564
6
- parabellum/run.py,sha256=EO_F7VPwayatpSHrcbSahtinsV4QObhcx0jo-4KZO1E,3472
7
- parabellum/vis.py,sha256=q7_OIMjzt-7nBOojVVW7Wiiq9ojsjaltIsH6eOOxPKk,6116
8
- parabellum-0.2.26.dist-info/METADATA,sha256=AJwdmHRRPG2MosgufeWqQH7LsDRtTvFFGMh1azey9zA,2671
9
- parabellum-0.2.26.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
10
- parabellum-0.2.26.dist-info/RECORD,,