parabellum 0.4.0__py3-none-any.whl → 0.5.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
parabellum/__init__.py CHANGED
@@ -1,22 +1,8 @@
1
- from .env import Environment, Scenario, make_scenario, State
2
- from .vis import Visualizer, Skin
3
- from .gun import bullet_fn
4
- from . import vis
5
- from . import terrain_db
6
- from . import env
7
- from . import tps
8
- # from .run import run
1
+ from . import aid, env, geo, types
9
2
 
10
3
  __all__ = [
4
+ "geo",
11
5
  "env",
12
- "terrain_db",
13
- "vis",
14
- "tps",
15
- "Environment",
16
- "Scenario",
17
- "make_scenario",
18
- "State",
19
- "Visualizer",
20
- "Skin",
21
- "bullet_fn",
6
+ "aid",
7
+ "types",
22
8
  ]
parabellum/aid.py CHANGED
@@ -3,10 +3,9 @@
3
3
  # by: Noah Syrkis
4
4
 
5
5
  # imports
6
- import os
7
6
  from collections import namedtuple
8
- from typing import Tuple
9
7
  import cartopy.crs as ccrs
8
+ import jax.numpy as jnp
10
9
 
11
10
  # types
12
11
  BBox = namedtuple("BBox", ["north", "south", "east", "west"]) # type: ignore
@@ -23,5 +22,20 @@ def to_mercator(bbox: BBox) -> BBox:
23
22
  def to_platecarree(bbox: BBox) -> BBox:
24
23
  proj = ccrs.PlateCarree()
25
24
  west, south = proj.transform_point(bbox.west, bbox.south, ccrs.Mercator())
25
+
26
26
  east, north = proj.transform_point(bbox.east, bbox.north, ccrs.Mercator())
27
27
  return BBox(north=north, south=south, east=east, west=west)
28
+
29
+
30
+ def obstacle_mask_fn(limit):
31
+ def aux(i, j):
32
+ xs = jnp.linspace(0, i + 1, i + j + 1)
33
+ ys = jnp.linspace(0, j + 1, i + j + 1)
34
+ cc = jnp.stack((xs, ys)).astype(jnp.int8)
35
+ mask = jnp.zeros((limit, limit)).at[*cc].set(1)
36
+ return mask
37
+
38
+ x = jnp.repeat(jnp.arange(limit), limit)
39
+ y = jnp.tile(jnp.arange(limit), limit)
40
+ mask = jnp.stack([aux(*c) for c in jnp.stack((x, y)).T])
41
+ return mask.astype(jnp.int8).reshape(limit, limit, limit, limit)
parabellum/env.py CHANGED
@@ -1,571 +1,113 @@
1
- """Parabellum environment based on SMAX"""
1
+ # env.py
2
+ # parabellum env
3
+ # by: Noah Syrkis
2
4
 
5
+ # % Imports
3
6
  import jax.numpy as jnp
4
- import jax
5
- from jax import random, Array, vmap, jit
6
- from flax.struct import dataclass
7
- import chex
8
- from jaxmarl.environments.smax.smax_env import SMAX
9
-
10
- from math import ceil
11
-
12
- from typing import Tuple, Dict, cast
7
+ from jax import random, Array, lax, vmap, debug
8
+ import jax.numpy.linalg as la
9
+ from typing import Tuple
13
10
  from functools import partial
14
- from parabellum import tps, geo, terrain_db
15
-
16
-
17
- @dataclass
18
- class Scenario:
19
- """Parabellum scenario"""
20
-
21
- place: str
22
- terrain: tps.Terrain
23
- unit_starting_sectors: jnp.ndarray # must be of size (num_units, 4) where sectors[i] = (x, y, width, height) of the ith unit's spawning sector (in % of the real map)
24
- unit_types: chex.Array
25
- num_allies: int
26
- num_enemies: int
27
-
28
- smacv2_position_generation: bool = False
29
- smacv2_unit_type_generation: bool = False
30
-
31
-
32
- @dataclass
33
- class State:
34
- # terrain: Array
35
- unit_positions: Array # fsfds
36
- unit_alive: Array
37
- unit_teams: Array
38
- unit_health: Array
39
- unit_types: Array
40
- unit_weapon_cooldowns: Array
41
- prev_movement_actions: Array
42
- prev_attack_actions: Array
43
- time: int
44
- terminal: bool
45
-
46
-
47
- def make_scenario(
48
- place,
49
- size,
50
- unit_starting_sectors,
51
- allies_type,
52
- n_allies,
53
- enemies_type,
54
- n_enemies,
55
- ):
56
- if place in terrain_db.db:
57
- terrain = terrain_db.make_terrain(terrain_db.db[place], size)
58
- else:
59
- terrain = geo.geography_fn(place, size)
60
- if type(unit_starting_sectors) == list:
61
- default_sector = [
62
- 0,
63
- 0,
64
- size,
65
- size,
66
- ] # Noah feel confident that this is right. This means 50% chance. Sorry timothee if you end up here later. my bad bro.
67
- correct_unit_starting_sectors = []
68
- for i in range(n_allies + n_enemies):
69
- selected_sector = None
70
- for unit_ids, sector in unit_starting_sectors:
71
- if i in unit_ids:
72
- selected_sector = sector
73
- if selected_sector is None:
74
- selected_sector = default_sector
75
- correct_unit_starting_sectors.append(selected_sector)
76
- unit_starting_sectors = correct_unit_starting_sectors
77
- if type(allies_type) == int:
78
- allies = [allies_type] * n_allies
79
- else:
80
- assert len(allies_type) == n_allies
81
- allies = allies_type
82
-
83
- if type(enemies_type) == int:
84
- enemies = [enemies_type] * n_enemies
85
- else:
86
- assert len(enemies_type) == n_enemies
87
- enemies = enemies_type
88
- unit_types = jnp.array(allies + enemies, dtype=jnp.uint8)
89
- return Scenario(
90
- place,
91
- terrain,
92
- unit_starting_sectors, # type: ignore
93
- unit_types,
94
- n_allies,
95
- n_enemies,
96
- )
97
-
98
-
99
- def scenario_fn(place, size):
100
- # scenario function for Noah, cos the one above is confusing
101
- terrain = geo.geography_fn(place, size)
102
- num_allies = 10
103
- num_enemies = 10
104
- unit_types = jnp.array([0] * num_allies + [1] * num_enemies, dtype=jnp.uint8)
105
- # start units in default sectors
106
- unit_starting_sectors = jnp.array([[0, 0, 1, 1]] * (num_allies + num_enemies))
107
- return Scenario(
108
- place=place,
109
- terrain=terrain,
110
- unit_starting_sectors=unit_starting_sectors,
111
- unit_types=unit_types,
112
- num_allies=num_allies,
113
- num_enemies=num_enemies,
114
- )
115
-
116
-
117
- def spawn_fn(rng: jnp.ndarray, units_spawning_sectors):
118
- """Spawns n agents on a map."""
119
- spawn_positions = []
120
- for sector in units_spawning_sectors:
121
- rng, key_start, key_noise = random.split(rng, 3)
122
- noise = 0.25 + random.uniform(key_noise, (2,)) * 0.5
123
- idx = random.choice(key_start, sector[0].shape[0])
124
- coord = jnp.array([sector[0][idx], sector[1][idx]])
125
- spawn_positions.append(coord + noise)
126
- return jnp.array(spawn_positions, dtype=jnp.float32)
127
-
128
-
129
- def sectors_fn(sectors: jnp.ndarray, invalid_spawn_areas: jnp.ndarray):
130
- """
131
- sectors must be of size (num_units, 4) where sectors[i] = (x, y, width, height) of the ith unit's spawning sector (in % of the real map)
132
- """
133
- width, height = invalid_spawn_areas.shape
134
- spawning_sectors = []
135
- for sector in sectors:
136
- coordx, coordy = (
137
- jnp.array(sector[0] * width, dtype=jnp.int32),
138
- jnp.array(sector[1] * height, dtype=jnp.int32),
139
- )
140
- sector = (
141
- invalid_spawn_areas[
142
- coordx : coordx + ceil(sector[2] * width),
143
- coordy : coordy + ceil(sector[3] * height),
144
- ]
145
- == 0
146
- )
147
- valid = jnp.nonzero(sector)
148
- if valid[0].shape[0] == 0:
149
- raise ValueError(f"Sector {sector} only contains invalid spawn areas.")
150
- spawning_sectors.append(
151
- jnp.array(valid) + jnp.array([coordx, coordy]).reshape((2, -1))
152
- )
153
- return spawning_sectors
154
-
155
-
156
- class Environment(SMAX):
157
- def __init__(self, scenario: Scenario, **kwargs):
158
- map_height, map_width = scenario.terrain.building.shape
159
- args = dict(scenario=scenario, map_height=map_height, map_width=map_width)
160
- if "unit_type_pushable" in kwargs:
161
- self.unit_type_pushable = kwargs["unit_type_pushable"]
162
- del kwargs["unit_type_pushable"]
163
- else:
164
- self.unit_type_pushable = jnp.array([1, 1, 0, 0, 0, 1])
165
- if "reset_when_done" in kwargs:
166
- self.reset_when_done = kwargs["reset_when_done"]
167
- del kwargs["reset_when_done"]
168
- else:
169
- self.reset_when_done = True
170
- super(Environment, self).__init__(**args, walls_cause_death=False, **kwargs)
171
- self.terrain = scenario.terrain
172
- self.unit_starting_sectors = scenario.unit_starting_sectors
173
- # self.unit_type_names = ["tinker", "tailor", "soldier", "spy"]
174
- # self.unit_type_health = jnp.array([100, 100, 100, 100], dtype=jnp.float32)
175
- # self.unit_type_damage = jnp.array([10, 10, 10, 10], dtype=jnp.float32)
176
- self.scenario = scenario
177
- self.unit_type_velocities = (
178
- jnp.array([3.15, 2.25, 4.13, 3.15, 4.13, 3.15]) / 2.5
179
- if "unit_type_velocities" not in kwargs
180
- else kwargs["unit_type_velocities"]
181
- )
182
- self.unit_type_attack_blasts = jnp.zeros((3,), dtype=jnp.float32) # TODO: add
183
- self.max_steps = 200
184
- self._push_units_away = lambda state, firmness=1: state # overwrite push units
185
- self.spawning_sectors = sectors_fn(
186
- self.unit_starting_sectors,
187
- scenario.terrain.building + scenario.terrain.water,
188
- )
189
- self.resolution = (
190
- jnp.array(jnp.max(self.unit_type_sight_ranges), dtype=jnp.int32) * 2
191
- )
192
- self.t = jnp.tile(jnp.linspace(0, 1, self.resolution), (2, 1))
193
-
194
- def reset(self, rng: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]: # type: ignore
195
- """Environment-specific reset."""
196
- unit_positions = spawn_fn(rng, self.spawning_sectors)
197
- unit_teams = jnp.zeros((self.num_agents,))
198
- unit_teams = unit_teams.at[self.num_allies :].set(1)
199
- unit_weapon_cooldowns = jnp.zeros((self.num_agents,))
200
- # default behaviour spawn all marines
201
- unit_types = cast(Array, self.scenario.unit_types)
202
- unit_health = self.unit_type_health[unit_types]
203
- state = State(
204
- unit_positions=unit_positions,
205
- unit_alive=jnp.ones((self.num_agents,), dtype=jnp.bool_),
206
- unit_teams=unit_teams,
207
- unit_health=unit_health,
208
- unit_types=unit_types,
209
- prev_movement_actions=jnp.zeros((self.num_agents, 2)),
210
- prev_attack_actions=jnp.zeros((self.num_agents,), dtype=jnp.int32),
211
- time=0,
212
- terminal=False,
213
- unit_weapon_cooldowns=unit_weapon_cooldowns,
214
- # terrain=self.terrain,
215
- )
216
- state = self._push_units_away(state) # type: ignore could be slow
217
- obs = self.get_obs(state)
218
- # remove world_state from obs
219
- world_state = self.get_world_state(state)
220
- obs["world_state"] = jax.lax.stop_gradient(world_state)
221
- return obs, state
222
-
223
- # def step_env(self, rng, state: State, action: Array): # type: ignore
224
- # obs, state, rewards, dones, infos = super().step_env(rng, state, action)
225
- # delete world_state from obs
226
- # <<<<<<< HEAD
227
- # obs.pop("world_state")
228
- # if not self.reset_when_done:
229
- # for key in dones.keys():
230
- # dones[key] = False
231
- # return obs, state, rewards, dones, infos
232
- # =======
233
- obs.pop("world_state")
234
- if not self.reset_when_done:
235
- for key in dones.keys():
236
- infos[key] = dones[key]
237
- dones[key] = False
238
- return obs, state, rewards, dones, infos
239
-
240
- # >>>>>>> aeb13033e57083cc512a60f8f60a3db47a65ac32
241
-
242
- def get_obs_unit_list(self, state: State) -> Dict[str, chex.Array]: # type: ignore
243
- """Applies observation function to state."""
244
-
245
- def get_features(i, j):
246
- """Get features of unit j as seen from unit i"""
247
- # Can just keep them symmetrical for now.
248
- # j here means 'the jth unit that is not i'
249
- # The observation is such that allies are always first
250
- # so for units in the second team we count in reverse.
251
- j = jax.lax.cond(
252
- i < self.num_allies, lambda: j, lambda: self.num_agents - j - 1
253
- )
254
- offset = jax.lax.cond(i < self.num_allies, lambda: 1, lambda: -1)
255
- j_idx = jax.lax.cond(
256
- ((j < i) & (i < self.num_allies)) | ((j > i) & (i >= self.num_allies)),
257
- lambda: j,
258
- lambda: j + offset,
259
- )
260
- empty_features = jnp.zeros(shape=(len(self.unit_features),))
261
- features = self._observe_features(state, i, j_idx)
262
- visible = (
263
- jnp.linalg.norm(state.unit_positions[j_idx] - state.unit_positions[i])
264
- < self.unit_type_sight_ranges[state.unit_types[i]]
265
- )
266
- return jax.lax.cond(
267
- visible
268
- & state.unit_alive[i]
269
- & state.unit_alive[j_idx]
270
- & self.has_line_of_sight(
271
- state.unit_positions[j_idx],
272
- state.unit_positions[i],
273
- self.terrain.building + self.terrain.forest,
274
- ),
275
- lambda: features,
276
- lambda: empty_features,
277
- )
278
-
279
- get_all_features_for_unit = jax.vmap(get_features, in_axes=(None, 0))
280
- get_all_features = jax.vmap(get_all_features_for_unit, in_axes=(0, None))
281
- other_unit_obs = get_all_features(
282
- jnp.arange(self.num_agents), jnp.arange(self.num_agents - 1)
283
- )
284
- other_unit_obs = other_unit_obs.reshape((self.num_agents, -1))
285
- get_all_self_features = jax.vmap(self._get_own_features, in_axes=(None, 0))
286
- own_unit_obs = get_all_self_features(state, jnp.arange(self.num_agents))
287
- obs = jnp.concatenate([other_unit_obs, own_unit_obs], axis=-1)
288
- return {agent: obs[self.agent_ids[agent]] for agent in self.agents}
289
-
290
- def _our_push_units_away(
291
- self, pos, unit_types, firmness: float = 1.0
292
- ): # copy of SMAX._push_units_away but used without state and called inside _world_step to allow more obstacles constraints
293
- delta_matrix = pos[:, None] - pos[None, :]
294
- dist_matrix = (
295
- jnp.linalg.norm(delta_matrix, axis=-1)
296
- + jnp.identity(self.num_agents)
297
- + 1e-6
298
- )
299
- radius_matrix = (
300
- self.unit_type_radiuses[unit_types][:, None]
301
- + self.unit_type_radiuses[unit_types][None, :]
302
- )
303
- overlap_term = jax.nn.relu(radius_matrix / dist_matrix - 1.0)
304
- unit_positions = (
305
- pos
306
- + firmness * jnp.sum(delta_matrix * overlap_term[:, :, None], axis=1) / 2
307
- )
308
- return jnp.where(
309
- self.unit_type_pushable[unit_types][:, None], unit_positions, pos
310
- )
311
-
312
- def has_line_of_sight(self, source, target, raster_input):
313
- # suppose that the target is in sight_range of source, otherwise the line of sight might miss some cells
314
- cells = jnp.array(
315
- source[:, jnp.newaxis] * self.t + (1 - self.t) * target[:, jnp.newaxis],
316
- dtype=jnp.int32,
317
- )
318
- mask = jnp.zeros(raster_input.shape).at[cells[0, :], cells[1, :]].set(1)
319
- flag = ~jnp.any(jnp.logical_and(mask, raster_input))
320
- return flag
321
-
322
- @partial(jax.jit, static_argnums=(0,)) # replace the _world_step method
323
- def _world_step( # modified version of JaxMARL's SMAX _world_step
324
- self,
325
- key: chex.PRNGKey,
326
- state: State,
327
- actions: Tuple[chex.Array, chex.Array],
328
- ) -> State:
329
- def raster_crossing(pos, new_pos, mask: jnp.ndarray):
330
- pos, new_pos = pos.astype(jnp.int32), new_pos.astype(jnp.int32)
331
- minimum = jnp.minimum(pos, new_pos)
332
- maximum = jnp.maximum(pos, new_pos)
333
- mask = jnp.where(jnp.arange(mask.shape[0]) >= minimum[0], mask.T, 0).T
334
- mask = jnp.where(jnp.arange(mask.shape[0]) <= maximum[0], mask.T, 0).T
335
- mask = jnp.where(jnp.arange(mask.shape[1]) >= minimum[1], mask, 0)
336
- mask = jnp.where(jnp.arange(mask.shape[1]) <= maximum[1], mask, 0)
337
- return jnp.any(mask)
338
-
339
- def update_position(idx, vec):
340
- # Compute the movements slightly strangely.
341
- # The velocities below are for diagonal directions
342
- # because these are easier to encode as actions than the four
343
- # diagonal directions. Then rotate the velocity 45
344
- # degrees anticlockwise to compute the movement.
345
- pos = cast(Array, state.unit_positions[idx])
346
- new_pos = (
347
- pos
348
- + vec
349
- * self.unit_type_velocities[state.unit_types[idx]]
350
- * self.time_per_step
351
- )
352
- # avoid going out of bounds
353
- new_pos = jnp.maximum(
354
- jnp.minimum(
355
- new_pos, jnp.array([self.map_width - 1, self.map_height - 1])
356
- ),
357
- jnp.zeros((2,)),
358
- )
359
-
360
- #######################################################################
361
- ############################################ avoid going into obstacles
362
- clash = raster_crossing(
363
- pos, new_pos, self.terrain.building + self.terrain.water
364
- )
365
- new_pos = jnp.where(clash, pos, new_pos)
366
-
367
- #######################################################################
368
- #######################################################################
369
-
370
- return new_pos
371
-
372
- #######################################################################
373
- ######################################### units close enough to get hit
374
-
375
- def bystander_fn(attacked_idx):
376
- idxs = jnp.zeros((self.num_agents,))
377
- idxs *= (
378
- jnp.linalg.norm(
379
- state.unit_positions - state.unit_positions[attacked_idx], axis=-1
380
- )
381
- < self.unit_type_attack_blasts[state.unit_types[attacked_idx]]
382
- )
383
- return idxs
384
-
385
- #######################################################################
386
- #######################################################################
387
-
388
- def update_agent_health(idx, action, key): # TODO: add attack blasts
389
- # for team 1, their attack actions are labelled in
390
- # reverse order because that is the order they are
391
- # observed in
392
- attacked_idx = jax.lax.cond(
393
- idx < self.num_allies,
394
- lambda: action + self.num_allies - self.num_movement_actions,
395
- lambda: self.num_allies - 1 - (action - self.num_movement_actions),
396
- )
397
- attacked_idx = cast(int, attacked_idx) # Cast to int
398
- # deal with no-op attack actions (i.e. agents that are moving instead)
399
- attacked_idx = jax.lax.select(
400
- action < self.num_movement_actions, idx, attacked_idx
401
- )
402
- distance = jnp.linalg.norm(
403
- state.unit_positions[idx] - state.unit_positions[attacked_idx]
404
- )
405
- attack_valid = (
406
- (distance <= self.unit_type_attack_ranges[state.unit_types[idx]])
407
- & state.unit_alive[idx]
408
- & state.unit_alive[attacked_idx]
409
- )
410
- attack_valid = attack_valid & (idx != attacked_idx)
411
- attack_valid = attack_valid & (state.unit_weapon_cooldowns[idx] <= 0.0)
412
- health_diff = jax.lax.select(
413
- attack_valid,
414
- -self.unit_type_attacks[state.unit_types[idx]],
415
- 0.0,
416
- )
417
- health_diff = jnp.where(
418
- state.unit_types[idx] == 1,
419
- health_diff
420
- * distance
421
- / self.unit_type_attack_ranges[state.unit_types[idx]],
422
- health_diff,
423
- )
424
- # design choice based on the pysc2 randomness details.
425
- # See https://github.com/deepmind/pysc2/blob/master/docs/environment.md#determinism-and-randomness
426
-
427
- #########################################################
428
- ############################### Add bystander health diff
429
-
430
- # bystander_idxs = bystander_fn(attacked_idx) # TODO: use
431
- # bystander_valid = (
432
- # jnp.where(attack_valid, bystander_idxs, jnp.zeros((self.num_agents,)))
433
- # .astype(jnp.bool_) # type: ignore
434
- # .astype(jnp.float32)
435
- # )
436
- # bystander_health_diff = (
437
- # bystander_valid * -self.unit_type_attacks[state.unit_types[idx]]
438
- # )
439
-
440
- #########################################################
441
- #########################################################
442
-
443
- cooldown_deviation = jax.random.uniform(
444
- key, minval=-self.time_per_step, maxval=2 * self.time_per_step
445
- )
446
- cooldown = (
447
- self.unit_type_weapon_cooldowns[state.unit_types[idx]]
448
- + cooldown_deviation
449
- )
450
- cooldown_diff = jax.lax.select(
451
- attack_valid,
452
- # subtract the current cooldown because we are
453
- # going to add it back. This way we effectively
454
- # set the new cooldown to `cooldown`
455
- cooldown - state.unit_weapon_cooldowns[idx],
456
- -self.time_per_step,
457
- )
458
- return (
459
- health_diff,
460
- attacked_idx,
461
- cooldown_diff,
462
- # (bystander_health_diff, bystander_idxs),
463
- )
464
-
465
- def perform_agent_action(idx, action, key):
466
- movement_action, attack_action = action
467
- new_pos = update_position(idx, movement_action)
468
- health_diff, attacked_idxes, cooldown_diff = update_agent_health(
469
- idx, attack_action, key
470
- )
471
-
472
- return new_pos, (health_diff, attacked_idxes), cooldown_diff
473
-
474
- keys = jax.random.split(key, num=self.num_agents)
475
- pos, (health_diff, attacked_idxes), cooldown_diff = jax.vmap(
476
- perform_agent_action
477
- )(jnp.arange(self.num_agents), actions, keys)
478
-
479
- # units push each other
480
- new_pos = self._our_push_units_away(pos, state.unit_types)
481
- clash = jax.vmap(raster_crossing, in_axes=(0, 0, None))(
482
- pos, new_pos, self.terrain.building + self.terrain.water
483
- )
484
- pos = jax.vmap(jnp.where)(clash, pos, new_pos)
485
- # avoid going out of bounds
486
- pos = jnp.maximum(
487
- jnp.minimum(pos, jnp.array([self.map_width - 1, self.map_height - 1])), # type: ignore
488
- jnp.zeros((2,)),
489
- )
490
-
491
- # Multiple enemies can attack the same unit.
492
- # We have `(health_diff, attacked_idx)` pairs.
493
- # `jax.lax.scatter_add` aggregates these exactly
494
- # in the way we want -- duplicate idxes will have their
495
- # health differences added together. However, it is a
496
- # super thin wrapper around the XLA scatter operation,
497
- # which has this bonkers syntax and requires this dnums
498
- # parameter. The usage here was inferred from a test:
499
- # https://github.com/google/jax/blob/main/tests/lax_test.py#L2296
500
- dnums = jax.lax.ScatterDimensionNumbers(
501
- update_window_dims=(),
502
- inserted_window_dims=(0,),
503
- scatter_dims_to_operand_dims=(0,),
504
- )
505
- unit_health = jnp.maximum(
506
- jax.lax.scatter_add(
507
- state.unit_health,
508
- jnp.expand_dims(attacked_idxes, 1),
509
- health_diff,
510
- dnums,
511
- ),
512
- 0.0,
513
- )
514
-
515
- #########################################################
516
- ############################ subtracting bystander health
517
-
518
- # _, bystander_health_diff = bystander
519
- # unit_health -= bystander_health_diff.sum(axis=0) # might be axis=1
520
-
521
- #########################################################
522
- #########################################################
523
-
524
- unit_weapon_cooldowns = state.unit_weapon_cooldowns + cooldown_diff
525
- # replace unit health, unit positions and unit weapon cooldowns
526
- state = state.replace( # type: ignore
527
- unit_health=unit_health,
528
- unit_positions=pos,
529
- unit_weapon_cooldowns=unit_weapon_cooldowns,
530
- )
531
- return state
532
-
533
-
534
- if __name__ == "__main__":
535
- n_envs = 4
536
-
537
- n_allies = 10
538
- scenario_kwargs = {
539
- "allies_type": 0,
540
- "n_allies": n_allies,
541
- "enemies_type": 0,
542
- "n_enemies": n_allies,
543
- "place": "Vesterbro, Copenhagen, Denmark",
544
- "size": 100,
545
- "unit_starting_sectors": [
546
- ([i for i in range(n_allies)], [0.0, 0.45, 0.1, 0.1]),
547
- ([n_allies + i for i in range(n_allies)], [0.8, 0.5, 0.1, 0.1]),
548
- ],
549
- }
550
- scenario = make_scenario(**scenario_kwargs)
551
- env = Environment(scenario)
552
- rng, reset_rng = random.split(random.PRNGKey(0))
553
- reset_key = random.split(reset_rng, n_envs)
554
- obs, state = vmap(env.reset)(reset_key)
555
- state_seq = []
556
-
557
- import time
558
11
 
559
- step = vmap(jit(env.step))
560
- tic = time.time()
561
- for i in range(10):
562
- rng, act_rng, step_rng = random.split(rng, 3)
563
- act_key = random.split(act_rng, (len(env.agents), n_envs))
564
- act = {
565
- a: vmap(env.action_space(a).sample)(act_key[i])
566
- for i, a in enumerate(env.agents)
567
- }
568
- step_key = random.split(step_rng, n_envs)
569
- state_seq.append((step_key, state, act))
570
- obs, state, reward, done, infos = step(step_key, state, act)
571
- tic = time.time()
12
+ from parabellum.geo import geography_fn
13
+ from parabellum.types import Action, State, Obs, Scene
14
+ from parabellum import aid
15
+ import equinox as eqx
16
+
17
+
18
+ # %% Dataclass ################################################################
19
+ class Env:
20
+ def __init__(self, cfg):
21
+ self.cfg = cfg
22
+
23
+ def reset(self, rng: Array, scene: Scene) -> Tuple[Obs, State]:
24
+ return init_fn(rng, self, scene)
25
+
26
+ def step(self, rng: Array, scene: Scene, state: State, action: Action) -> Tuple[Obs, State]:
27
+ return obs_fn(self, scene, state), step_fn(rng, self, scene, state, action)
28
+
29
+ @property
30
+ def num_units(self):
31
+ return sum(self.cfg.counts.allies.values()) + sum(self.cfg.counts.enemies.values())
32
+
33
+ @property
34
+ def num_allies(self):
35
+ return sum(self.cfg.counts.allies.values())
36
+
37
+ @property
38
+ def num_enemies(self):
39
+ return sum(self.cfg.counts.enemies.values())
40
+
41
+
42
+ # %% Functions ################################################################
43
+ @eqx.filter_jit
44
+ def init_fn(rng: Array, env: Env, scene: Scene) -> Tuple[Obs, State]: # initialize -----
45
+ keys = random.split(rng)
46
+ health = jnp.ones(env.num_units) * scene.unit_type_health[scene.unit_types]
47
+ pos = random.normal(keys[1], (scene.unit_types.size, 2)) * 2 + env.cfg.size / 2
48
+ state = State(unit_position=pos, unit_health=health, unit_cooldown=jnp.zeros(env.num_units)) # state --
49
+ return obs_fn(env, scene, state), state # return observation and state of agents --
50
+
51
+
52
+ @eqx.filter_jit # knn from env.cfg never changes, so we can jit it
53
+ def obs_fn(env, scene: Scene, state: State) -> Obs: # return info about neighbors ---
54
+ distances = la.norm(state.unit_position[:, None] - state.unit_position, axis=-1) # all dist --
55
+ dists, idxs = lax.approx_min_k(distances, k=env.cfg.knn)
56
+ mask = mask_fn(scene, state, dists, idxs)
57
+ health = state.unit_health[idxs] * mask
58
+ cooldown = state.unit_cooldown[idxs] * mask
59
+ unit_pos = (state.unit_position[:, None, ...] - state.unit_position[idxs]) * mask[..., None]
60
+ return Obs(unit_id=idxs, unit_pos=unit_pos, unit_health=health, unit_cooldown=cooldown)
61
+
62
+
63
+ @eqx.filter_jit
64
+ def step_fn(rng, env: Env, scene: Scene, state: State, action: Action) -> State: # update agents ---
65
+ newpos = state.unit_position + action.coord * (1 - action.kinds[..., None])
66
+ bounds = ((newpos < 0).any(axis=-1) | (newpos >= env.cfg.size).any(axis=-1))[..., None]
67
+ builds = (scene.terrain.building[*newpos.astype(jnp.int32).T] > 0)[..., None]
68
+ newpos = jnp.where(bounds | builds, state.unit_position, newpos) # use old pos if new is not valid
69
+ new_hp = blast_fn(rng, env, scene, state, action)
70
+ return State(unit_position=newpos, unit_health=new_hp, unit_cooldown=state.unit_cooldown) # return -
71
+
72
+
73
+ def blast_fn(rng, env: Env, scene: Scene, state: State, action: Action): # update agents ---
74
+ dist = la.norm(state.unit_position[None, ...] - (state.unit_position + action.coord)[:, None, ...], axis=-1)
75
+ hits = dist <= scene.unit_type_reach[scene.unit_types][None, ...] * action.kinds[..., None] # mask non attack act
76
+ damage = (hits * scene.unit_type_damage[scene.unit_types][None, ...]).sum(axis=-1)
77
+ return state.unit_health - damage
78
+
79
+
80
+ # @eqx.filter_jit
81
+ def scene_fn(cfg): # init's a scene
82
+ aux = lambda key: jnp.array([x[key] for x in sorted(cfg.types, key=lambda x: x.name)]) # noqa
83
+ attrs = ["health", "damage", "reload", "reach", "sight", "speed"]
84
+ kwargs = {f"unit_type_{a}": aux(a) for a in attrs} | {"terrain": geography_fn(cfg.place, cfg.size)}
85
+ num_allies, num_enemies = sum(cfg.counts.allies.values()), sum(cfg.counts.enemies.values())
86
+ unit_teams = jnp.concat((jnp.zeros(num_allies), jnp.ones(num_enemies))).astype(jnp.int32)
87
+ aux = lambda t: jnp.concat([jnp.zeros(x) + i for i, x in enumerate([x[1] for x in sorted(cfg.counts[t].items())])]) # noqa
88
+ unit_types = jnp.concat((aux("allies"), aux("enemies"))).astype(jnp.int32)
89
+ mask = aid.obstacle_mask_fn(max([x["sight"] for x in cfg.types]))
90
+ return Scene(unit_teams=unit_teams, unit_types=unit_types, mask=mask, **kwargs) # type: ignore
91
+
92
+
93
+ @eqx.filter_jit
94
+ def mask_fn(scene, state, dists, idxs):
95
+ mask = dists < scene.unit_type_sight[scene.unit_types][..., None] # mask for removing hidden
96
+ mask = mask | obstacle_fn(scene, state.unit_position[idxs].astype(jnp.int8))
97
+ return mask
98
+
99
+
100
+ @partial(vmap, in_axes=(None, 0)) # 5 x 2 # not the best name for a fn
101
+ def obstacle_fn(scene, pos):
102
+ slice = slice_fn(scene, pos[0], pos)
103
+ return slice
104
+
105
+
106
+ @partial(vmap, in_axes=(None, None, 0))
107
+ def slice_fn(scene, source, target): # returns a 10 x 10 view with unit at top left corner, and terrain downwards
108
+ delta = ((source - target) >= 0) * 2 - 1
109
+ coord = jnp.sort(jnp.stack((source, source + delta * 10)), axis=0)[0]
110
+ slice = lax.dynamic_slice(scene.terrain.building, coord, (scene.mask.shape[-1], scene.mask.shape[-1]))
111
+ slice = lax.cond(delta[0] == 1, lambda: jnp.flip(slice), lambda: slice)
112
+ slice = lax.cond(delta[1] == 1, lambda: jnp.flip(slice, axis=1), lambda: slice)
113
+ return (scene.mask[*jnp.abs(source - target)] * slice).sum() == 0
parabellum/geo.py CHANGED
@@ -3,25 +3,23 @@
3
3
  # by: Noah Syrkis
4
4
 
5
5
  # %% Imports
6
- from parabellum import tps
7
- import rasterio
8
6
  from rasterio import features, transform
7
+ from jax import tree
9
8
  from geopy.geocoders import Nominatim
10
9
  from geopy.distance import distance
11
10
  import contextily as cx
12
11
  import jax.numpy as jnp
13
12
  import cartopy.crs as ccrs
14
13
  from jaxtyping import Array
15
- import numpy as np
16
14
  from shapely import box
17
15
  import osmnx as ox
18
16
  import geopandas as gpd
19
17
  from collections import namedtuple
20
18
  from typing import Tuple
21
19
  import matplotlib.pyplot as plt
22
- import seaborn as sns
23
- import os
20
+ from cachier import cachier
24
21
  from jax.scipy.signal import convolve
22
+ from parabellum.types import Terrain
25
23
 
26
24
  # %% Types
27
25
  Coords = Tuple[float, float]
@@ -79,27 +77,23 @@ def basemap_fn(bbox: BBox, gdf) -> Array:
79
77
  return image
80
78
 
81
79
 
82
- def geography_fn(place, buffer=400):
80
+ @cachier()
81
+ def geography_fn(place, buffer) -> Terrain:
83
82
  bbox = get_bbox(place, buffer)
84
83
  map_data = ox.features_from_bbox(bbox=bbox, tags=tags)
85
84
  gdf = gpd.GeoDataFrame(map_data)
86
- gdf = gdf.clip(box(bbox.west, bbox.south, bbox.east, bbox.north)).to_crs(
87
- "EPSG:3857"
88
- )
85
+ gdf = gdf.clip(box(bbox.west, bbox.south, bbox.east, bbox.north)).to_crs("EPSG:3857")
89
86
  raster = raster_fn(gdf, shape=(buffer, buffer))
90
87
  basemap = jnp.rot90(basemap_fn(bbox, gdf), 3)
91
- # 0: building", 1: "water", 2: "highway", 3: "forest", 4: "garden"
92
88
  kernel = jnp.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]])
93
- trans = lambda x: jnp.rot90(x, 3)
94
- # <<<<<<< HEAD
95
- terrain = tps.Terrain(
89
+ trans = lambda x: jnp.rot90(x, 3) # noqa
90
+ terrain = Terrain(
96
91
  building=trans(raster[0]),
97
- water=trans(
98
- raster[1] - convolve(raster[1] * raster[2], kernel, mode="same") > 0
99
- ),
92
+ water=trans(raster[1] - convolve(raster[1] * raster[2], kernel, mode="same") > 0),
100
93
  forest=trans(jnp.logical_or(raster[3], raster[4])),
101
94
  basemap=basemap,
102
95
  )
96
+ terrain = tree.map(lambda x: x.astype(jnp.int16), terrain)
103
97
  return terrain
104
98
 
105
99
 
@@ -128,25 +122,25 @@ def feature_fn(t, feature, gdf, shape):
128
122
 
129
123
 
130
124
  # %%
131
- def normalize(x):
132
- return (np.array(x) - m) / (M - m)
133
-
134
-
135
- def get_bridges(gdf):
136
- xmin, ymin, xmax, ymax = gdf.total_bounds
137
- m = np.array([xmin, ymin])
138
- M = np.array([xmax, ymax])
139
-
140
- bridges = {}
141
- for idx, bridge in gdf[gdf["bridge"] == "yes"].iterrows():
142
- if type(bridge["name"]) == str:
143
- bridges[idx[1]] = {
144
- "name": bridge["name"],
145
- "coords": normalize(
146
- [bridge.geometry.centroid.x, bridge.geometry.centroid.y]
147
- ),
148
- }
149
- return bridges
125
+ # def normalize(x):
126
+ # return (np.array(x) - m) / (M - m)
127
+
128
+
129
+ # def get_bridges(gdf):
130
+ # xmin, ymin, xmax, ymax = gdf.total_bounds
131
+ # m = np.array([xmin, ymin])
132
+ # M = np.array([xmax, ymax])
133
+
134
+ # bridges = {}
135
+ # for idx, bridge in gdf[gdf["bridge"] == "yes"].iterrows():
136
+ # if type(bridge["name"]) == str:
137
+ # bridges[idx[1]] = {
138
+ # "name": bridge["name"],
139
+ # "coords": normalize(
140
+ # [bridge.geometry.centroid.x, bridge.geometry.centroid.y]
141
+ # ),
142
+ # }
143
+ # return bridges
150
144
 
151
145
 
152
146
  """
parabellum/model.py ADDED
@@ -0,0 +1,6 @@
1
+ # model.py
2
+ # jax model for mapping from observation to action
3
+ # by: Noah Syrkis
4
+
5
+ # Imports
6
+ import jax.numpy as jnp
parabellum/ppo.py ADDED
@@ -0,0 +1 @@
1
+
parabellum/terrain_db.py CHANGED
@@ -1,22 +1,22 @@
1
1
  # %%
2
2
  import numpy as np
3
3
  import jax.numpy as jnp
4
- from parabellum import tps
4
+ from parabellum.types import Terrain
5
5
 
6
6
 
7
7
  # %%
8
8
  def map_raster_from_line(raster, line, size):
9
9
  x0, y0, dx, dy = line
10
- x0 = int(x0*size)
11
- y0 = int(y0*size)
12
- dx = int(dx*size)
13
- dy = int(dy*size)
10
+ x0 = int(x0 * size)
11
+ y0 = int(y0 * size)
12
+ dx = int(dx * size)
13
+ dy = int(dy * size)
14
14
  max_T = int(2**0.5 * size)
15
- for t in range(max_T+1):
16
- alpha = t/float(max_T)
17
- x = x0 if dx == 0 else int((1 - alpha) * x0 + alpha * (x0+dx))
18
- y = y0 if dy == 0 else int((1 - alpha) * y0 + alpha * (y0+dy))
19
- if 0<=x<size and 0<=y<size:
15
+ for t in range(max_T + 1):
16
+ alpha = t / float(max_T)
17
+ x = x0 if dx == 0 else int((1 - alpha) * x0 + alpha * (x0 + dx))
18
+ y = y0 if dy == 0 else int((1 - alpha) * y0 + alpha * (y0 + dy))
19
+ if 0 <= x < size and 0 <= y < size:
20
20
  raster[x, y] = 1
21
21
  return raster
22
22
 
@@ -24,20 +24,21 @@ def map_raster_from_line(raster, line, size):
24
24
  # %%
25
25
  def map_raster_from_rect(raster, rect, size):
26
26
  x0, y0, dx, dy = rect
27
- x0 = int(x0*size)
28
- y0 = int(y0*size)
29
- dx = int(dx*size)
30
- dy = int(dy*size)
31
- raster[x0:x0+dx, y0:y0+dy] = 1
27
+ x0 = int(x0 * size)
28
+ y0 = int(y0 * size)
29
+ dx = int(dx * size)
30
+ dy = int(dy * size)
31
+ raster[x0 : x0 + dx, y0 : y0 + dy] = 1
32
32
  return raster
33
33
 
34
34
 
35
35
  # %%
36
- building_color = jnp.array([201,199,198, 255])
36
+ building_color = jnp.array([201, 199, 198, 255])
37
37
  water_color = jnp.array([193, 237, 254, 255])
38
- forest_color = jnp.array([197,214,185, 255])
38
+ forest_color = jnp.array([197, 214, 185, 255])
39
39
  empty_color = jnp.array([255, 255, 255, 255])
40
40
 
41
+
41
42
  def make_terrain(terrain_args, size):
42
43
  args = {}
43
44
  for key, config in terrain_args.items():
@@ -49,44 +50,75 @@ def make_terrain(terrain_args, size):
49
50
  elif "rect" in elem:
50
51
  raster = map_raster_from_rect(raster, elem["rect"], size)
51
52
  args[key] = jnp.array(raster.T)
52
- basemap = jnp.where(args["building"][:,:,None], jnp.tile(building_color, (size, size, 1)), jnp.tile(empty_color, (size,size, 1)))
53
- basemap = jnp.where(args["water"][:,:,None], jnp.tile(water_color, (size, size, 1)), basemap)
54
- basemap = jnp.where(args["forest"][:,:,None], jnp.tile(forest_color, (size, size, 1)), basemap)
53
+ basemap = jnp.where(
54
+ args["building"][:, :, None], jnp.tile(building_color, (size, size, 1)), jnp.tile(empty_color, (size, size, 1))
55
+ )
56
+ basemap = jnp.where(args["water"][:, :, None], jnp.tile(water_color, (size, size, 1)), basemap)
57
+ basemap = jnp.where(args["forest"][:, :, None], jnp.tile(forest_color, (size, size, 1)), basemap)
55
58
  args["basemap"] = basemap
56
- return tps.Terrain(**args)
59
+ return Terrain(**args)
57
60
 
58
61
 
59
62
  # %%
60
63
  db = {
61
- "blank": {'building': None, 'water': None, 'forest': None},
62
- "F": {'building': [{"line": [0.25, 0.33, 0.5, 0]}, {"line":[0.75, 0.33, 0., 0.25]}, {"line":[0.50, 0.33, 0., 0.25]}], 'water': None, 'forest': None},
63
- "stronghold": {'building': [
64
- {"line":[0.2, 0.275, 0.2, 0.]}, {"line":[0.2, 0.275, 0.0, 0.2]},
65
- {"line":[0.4, 0.275, 0.0, 0.2]}, {"line":[0.2, 0.475, 0.2, 0.]},
66
-
67
- {"line":[0.2, 0.525, 0.2, 0.]}, {"line": [0.2, 0.525, 0.0, 0.2]},
68
- {"line":[0.4, 0.525, 0.0, 0.2]}, {"line": [0.2, 0.725, 0.525, 0.]},
69
-
70
- {"line":[0.75, 0.25, 0., 0.2]}, {"line":[0.75, 0.55, 0., 0.19]},
71
- {"line":[0.6, 0.25, 0.15, 0.]}], 'water': None, 'forest': None},
72
- "playground": {'building': [{"line":[0.5, 0.5, 0.5, 0.]}], 'water': None, 'forest': None},
64
+ "blank": {"building": None, "water": None, "forest": None},
65
+ "F": {
66
+ "building": [
67
+ {"line": [0.25, 0.33, 0.5, 0]},
68
+ {"line": [0.75, 0.33, 0.0, 0.25]},
69
+ {"line": [0.50, 0.33, 0.0, 0.25]},
70
+ ],
71
+ "water": None,
72
+ "forest": None,
73
+ },
74
+ "stronghold": {
75
+ "building": [
76
+ {"line": [0.2, 0.275, 0.2, 0.0]},
77
+ {"line": [0.2, 0.275, 0.0, 0.2]},
78
+ {"line": [0.4, 0.275, 0.0, 0.2]},
79
+ {"line": [0.2, 0.475, 0.2, 0.0]},
80
+ {"line": [0.2, 0.525, 0.2, 0.0]},
81
+ {"line": [0.2, 0.525, 0.0, 0.2]},
82
+ {"line": [0.4, 0.525, 0.0, 0.2]},
83
+ {"line": [0.2, 0.725, 0.525, 0.0]},
84
+ {"line": [0.75, 0.25, 0.0, 0.2]},
85
+ {"line": [0.75, 0.55, 0.0, 0.19]},
86
+ {"line": [0.6, 0.25, 0.15, 0.0]},
87
+ ],
88
+ "water": None,
89
+ "forest": None,
90
+ },
91
+ "playground": {"building": [{"line": [0.5, 0.5, 0.5, 0.0]}], "water": None, "forest": None},
73
92
  "playground2": {
74
- 'building': [],
75
- "water": [{"rect":[0., 0.8, 0.1, 0.1]}, {"rect": [0.2, 0.8, 0.8, 0.1]}],
76
- "forest": [{"rect": [0., 0., 1., 0.2]}]
93
+ "building": [],
94
+ "water": [{"rect": [0.0, 0.8, 0.1, 0.1]}, {"rect": [0.2, 0.8, 0.8, 0.1]}],
95
+ "forest": [{"rect": [0.0, 0.0, 1.0, 0.2]}],
96
+ },
97
+ "triangle": {
98
+ "building": [{"line": [0.33, 0.0, 0.0, 1.0]}, {"line": [0.66, 0.0, 0.0, 1.0]}],
99
+ "water": None,
100
+ "forest": None,
77
101
  },
78
- "triangle": {'building': [{"line": [0.33, 0., 0., 1.]}, {"line": [0.66, 0., 0., 1.]}], 'water': None, 'forest': None},
79
102
  "u_shape": {
80
- 'building': [],
81
- "water": [{"rect": [0.15, 0.2, 0.1, 0.5]}, {"rect": [0.4, 0.2, 0.1, 0.5]}, {"rect": [0.2, 0.2, 0.25, 0.1]}],
82
- "forest": []
103
+ "building": [],
104
+ "water": [{"rect": [0.15, 0.2, 0.1, 0.5]}, {"rect": [0.4, 0.2, 0.1, 0.5]}, {"rect": [0.2, 0.2, 0.25, 0.1]}],
105
+ "forest": [],
83
106
  },
84
107
  "bridges": {
85
- 'building': [],
86
- "water": [{"rect": [0.475, 0., 0.05, 0.1]}, {"rect": [0.475, 0.15, 0.05, 0.575]}, {"rect": [0.475, 0.775, 0.05, 1.]},
87
- {"rect": [0., 0.475, 0.225, 0.05]}, {"rect": [0.275, 0.475, 0.45, 0.05]}, {"rect": [0.775, 0.475, 0.23, 0.05]}],
88
- "forest": [{"rect": [0.1, 0.625, 0.275, 0.275]}, {"rect": [0.725, 0., 0.3, 0.275]}, ]
89
- }
108
+ "building": [],
109
+ "water": [
110
+ {"rect": [0.475, 0.0, 0.05, 0.1]},
111
+ {"rect": [0.475, 0.15, 0.05, 0.575]},
112
+ {"rect": [0.475, 0.775, 0.05, 1.0]},
113
+ {"rect": [0.0, 0.475, 0.225, 0.05]},
114
+ {"rect": [0.275, 0.475, 0.45, 0.05]},
115
+ {"rect": [0.775, 0.475, 0.23, 0.05]},
116
+ ],
117
+ "forest": [
118
+ {"rect": [0.1, 0.625, 0.275, 0.275]},
119
+ {"rect": [0.725, 0.0, 0.3, 0.275]},
120
+ ],
121
+ },
90
122
  }
91
123
 
92
124
  # %% [raw]
@@ -128,7 +160,7 @@ if __name__ == "__main__":
128
160
  plt.imshow(jnp.rot90(terrain.basemap))
129
161
  bl = (39.5, 5)
130
162
  tr = (44.5, 10)
131
- plt.scatter(bl[0], 49-bl[1])
132
- plt.scatter(tr[0], 49-tr[1], marker="+")
163
+ plt.scatter(bl[0], 49 - bl[1])
164
+ plt.scatter(tr[0], 49 - tr[1], marker="+")
133
165
 
134
166
  # %%
parabellum/types.py ADDED
@@ -0,0 +1,54 @@
1
+ # types.py
2
+ # parabellum types
3
+ # by: Noah Syrkis
4
+
5
+ # imports
6
+ from chex import dataclass
7
+ from jaxtyping import Array, Bool, Float16
8
+
9
+
10
+ # dataclasses
11
+ @dataclass
12
+ class State:
13
+ unit_position: Array
14
+ unit_health: Array
15
+ unit_cooldown: Array
16
+
17
+
18
+ @dataclass
19
+ class Obs:
20
+ unit_id: Array
21
+ unit_pos: Array
22
+ unit_health: Array
23
+ unit_cooldown: Array
24
+
25
+
26
+ @dataclass
27
+ class Action:
28
+ coord: Float16[Array, "... 2"] # noqa
29
+ kinds: Bool[Array, "..."]
30
+
31
+
32
+ @dataclass
33
+ class Terrain:
34
+ building: Array
35
+ water: Array
36
+ forest: Array
37
+ basemap: Array
38
+
39
+
40
+ @dataclass
41
+ class Scene:
42
+ terrain: Terrain
43
+ mask: Array
44
+
45
+ unit_types: Array
46
+ unit_teams: Array
47
+
48
+ unit_type_health: Array
49
+ unit_type_damage: Array
50
+ unit_type_reload: Array
51
+
52
+ unit_type_reach: Array
53
+ unit_type_sight: Array
54
+ unit_type_speed: Array
@@ -1,39 +1,41 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.3
2
2
  Name: parabellum
3
- Version: 0.4.0
3
+ Version: 0.5.13
4
4
  Summary: Parabellum environment for parallel warfare simulation
5
- Home-page: https://github.com/syrkis/parabellum
6
- License: MIT
7
- Keywords: warfare,simulation,parallel,environment
8
5
  Author: Noah Syrkis
9
6
  Author-email: desk@syrkis.com
10
7
  Requires-Python: >=3.11,<3.12
11
- Classifier: License :: OSI Approved :: MIT License
12
8
  Classifier: Programming Language :: Python :: 3
13
9
  Classifier: Programming Language :: Python :: 3.11
10
+ Requires-Dist: brax (>=0.12.1,<0.13.0)
11
+ Requires-Dist: cachier (>=3.1.2,<4.0.0)
14
12
  Requires-Dist: cartopy (>=0.23.0,<0.24.0)
15
13
  Requires-Dist: contextily (>=1.6.0,<2.0.0)
16
- Requires-Dist: darkdetect (>=0.8.0,<0.9.0)
14
+ Requires-Dist: distrax (>=0.1.5,<0.2.0)
17
15
  Requires-Dist: einops (>=0.8.0,<0.9.0)
16
+ Requires-Dist: equinox (>=0.11.11,<0.12.0)
17
+ Requires-Dist: evosax (>=0.1.6,<0.2.0)
18
+ Requires-Dist: flashbax (>=0.1.2,<0.2.0)
19
+ Requires-Dist: flax (>=0.10.4,<0.11.0)
18
20
  Requires-Dist: folium (>=0.17.0,<0.18.0)
19
21
  Requires-Dist: geopy (>=2.4.1,<3.0.0)
22
+ Requires-Dist: gymnax (>=0.0.8,<0.0.9)
20
23
  Requires-Dist: ipykernel (>=6.29.5,<7.0.0)
21
- Requires-Dist: jax (==0.4.17)
22
- Requires-Dist: jaxmarl (==0.0.3)
24
+ Requires-Dist: jax (>=0.5.0,<0.6.0)
25
+ Requires-Dist: jax-tqdm (>=0.3.1,<0.4.0)
23
26
  Requires-Dist: jaxtyping (>=0.2.33,<0.3.0)
24
27
  Requires-Dist: jupyterlab (>=4.2.2,<5.0.0)
25
- Requires-Dist: moviepy (>=1.0.3,<2.0.0)
26
- Requires-Dist: numpy (<2)
27
- Requires-Dist: opencv-python (>=4.10.0.84,<5.0.0.0)
28
+ Requires-Dist: navix (>=0.7.0,<0.8.0)
29
+ Requires-Dist: numpy (>=2.2.3,<3.0.0)
30
+ Requires-Dist: omegaconf (>=2.3.0,<3.0.0)
31
+ Requires-Dist: optax (>=0.2.4,<0.3.0)
28
32
  Requires-Dist: osmnx (==2.0.0b0)
29
33
  Requires-Dist: pandas (>=2.2.2,<3.0.0)
30
34
  Requires-Dist: poetry (>=1.8.3,<2.0.0)
31
- Requires-Dist: pygame (>=2.5.2,<3.0.0)
32
35
  Requires-Dist: rasterio (>=1.3.10,<2.0.0)
33
- Requires-Dist: seaborn (>=0.13.2,<0.14.0)
34
36
  Requires-Dist: stadiamaps (>=3.2.1,<4.0.0)
35
37
  Requires-Dist: tqdm (>=4.66.4,<5.0.0)
36
- Project-URL: Repository, https://github.com/syrkis/parabellum
38
+ Requires-Dist: wandb (>=0.19.7,<0.20.0)
37
39
  Description-Content-Type: text/markdown
38
40
 
39
41
  # Parabellum
@@ -0,0 +1,15 @@
1
+ parabellum/__init__.py,sha256=Og0bpKlQtkWCJ1yaQk98LniI1sG3B7GAR7aPMFC-v74,96
2
+ parabellum/aid.py,sha256=hCp-eDONcloKMNpYni61cE6St_67Lks2ivS5eJjSHFQ,1379
3
+ parabellum/env.py,sha256=NmDXpjzwun1PsgaHPcIH-i4H8TfkY0Iv1sEy2PKoMnA,5357
4
+ parabellum/geo.py,sha256=Zv2GCwaYKsfB6CE_cwEHhddU9p9Z5e5lL16iNfETVmw,6181
5
+ parabellum/gun.py,sha256=nvsJdcZ2Qd6lbPlAgsUiaLhstTi1UdLQ8kOnbCenucY,2618
6
+ parabellum/model.py,sha256=o40jW2vp3Fwxt1KqykZ-qZXs73-H24nA4iiRoetAobA,117
7
+ parabellum/pcg.py,sha256=d8KC_lbc4WUUUPaTdPJSx27VMGioys3jSGOWJ-2EahU,968
8
+ parabellum/ppo.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
9
+ parabellum/run.py,sha256=Q53__AxzROZNgfZLVU5LDdcT61UMCkmQ_Q5wWUIrnqo,3473
10
+ parabellum/terrain_db.py,sha256=lPd56Qe4_xuVFJcKsftE9RvNN9nPr69GzBJZIS5nKrY,5207
11
+ parabellum/types.py,sha256=WwMzSQ5qnRo6rqKTHIGlWPcg57lx6ENDSK5DFr3R7-s,832
12
+ parabellum/vis.py,sha256=ABHveJj0fLRWkxOv3LFIXK20QtdGhjskuFLsp7iTFu0,6185
13
+ parabellum-0.5.13.dist-info/METADATA,sha256=g_tMs-pztJXVRE__E9VLORyZANegyalSoT7ilMz5oRc,2772
14
+ parabellum-0.5.13.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
15
+ parabellum-0.5.13.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: poetry-core 1.9.0
2
+ Generator: poetry-core 2.1.2
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
parabellum/tps.py DELETED
@@ -1,17 +0,0 @@
1
- # %%
2
- # tps.py
3
- # parabellum types and dataclasses
4
- # by: Noah Syrkis
5
-
6
- # %% Imports
7
- from chex import dataclass
8
- from jaxtyping import Array
9
-
10
-
11
- # %% Dataclasses
12
- @dataclass
13
- class Terrain:
14
- building: Array
15
- water: Array
16
- forest: Array
17
- basemap: Array
@@ -1,13 +0,0 @@
1
- parabellum/__init__.py,sha256=hIOLir7wgaf_HU4j8uos7PaCrofqPQcr3FcMlBsZyr8,406
2
- parabellum/aid.py,sha256=BPabjN4BUq1HRhkwbc9pCNsXSF_ALiG8W8cHWTWeEH4,900
3
- parabellum/env.py,sha256=2FDOI90IuoOTFV5DhLLWWMpuaj4mcwoPup24KK-duYI,22907
4
- parabellum/geo.py,sha256=PJs9UevibuokDVb3oJWNHvYHlMYGCxB5OkNSbDj48vI,6198
5
- parabellum/gun.py,sha256=nvsJdcZ2Qd6lbPlAgsUiaLhstTi1UdLQ8kOnbCenucY,2618
6
- parabellum/pcg.py,sha256=d8KC_lbc4WUUUPaTdPJSx27VMGioys3jSGOWJ-2EahU,968
7
- parabellum/run.py,sha256=Q53__AxzROZNgfZLVU5LDdcT61UMCkmQ_Q5wWUIrnqo,3473
8
- parabellum/terrain_db.py,sha256=5lHzbX94lzkb--cETpraXS42G4T4tKekISpTm4yaYEE,4748
9
- parabellum/tps.py,sha256=of-RBdelAbNCHQZd1I22RWmZkwUEh6f161mx0X_G2tE,257
10
- parabellum/vis.py,sha256=ABHveJj0fLRWkxOv3LFIXK20QtdGhjskuFLsp7iTFu0,6185
11
- parabellum-0.4.0.dist-info/METADATA,sha256=GrTOPfKE1HHK2R3RR7gpYgLBFzSIgOgDTpt7h3zFlCM,2707
12
- parabellum-0.4.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
13
- parabellum-0.4.0.dist-info/RECORD,,