parabellum 0.4.0__tar.gz → 0.5.13__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,39 +1,41 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.3
2
2
  Name: parabellum
3
- Version: 0.4.0
3
+ Version: 0.5.13
4
4
  Summary: Parabellum environment for parallel warfare simulation
5
- Home-page: https://github.com/syrkis/parabellum
6
- License: MIT
7
- Keywords: warfare,simulation,parallel,environment
8
5
  Author: Noah Syrkis
9
6
  Author-email: desk@syrkis.com
10
7
  Requires-Python: >=3.11,<3.12
11
- Classifier: License :: OSI Approved :: MIT License
12
8
  Classifier: Programming Language :: Python :: 3
13
9
  Classifier: Programming Language :: Python :: 3.11
10
+ Requires-Dist: brax (>=0.12.1,<0.13.0)
11
+ Requires-Dist: cachier (>=3.1.2,<4.0.0)
14
12
  Requires-Dist: cartopy (>=0.23.0,<0.24.0)
15
13
  Requires-Dist: contextily (>=1.6.0,<2.0.0)
16
- Requires-Dist: darkdetect (>=0.8.0,<0.9.0)
14
+ Requires-Dist: distrax (>=0.1.5,<0.2.0)
17
15
  Requires-Dist: einops (>=0.8.0,<0.9.0)
16
+ Requires-Dist: equinox (>=0.11.11,<0.12.0)
17
+ Requires-Dist: evosax (>=0.1.6,<0.2.0)
18
+ Requires-Dist: flashbax (>=0.1.2,<0.2.0)
19
+ Requires-Dist: flax (>=0.10.4,<0.11.0)
18
20
  Requires-Dist: folium (>=0.17.0,<0.18.0)
19
21
  Requires-Dist: geopy (>=2.4.1,<3.0.0)
22
+ Requires-Dist: gymnax (>=0.0.8,<0.0.9)
20
23
  Requires-Dist: ipykernel (>=6.29.5,<7.0.0)
21
- Requires-Dist: jax (==0.4.17)
22
- Requires-Dist: jaxmarl (==0.0.3)
24
+ Requires-Dist: jax (>=0.5.0,<0.6.0)
25
+ Requires-Dist: jax-tqdm (>=0.3.1,<0.4.0)
23
26
  Requires-Dist: jaxtyping (>=0.2.33,<0.3.0)
24
27
  Requires-Dist: jupyterlab (>=4.2.2,<5.0.0)
25
- Requires-Dist: moviepy (>=1.0.3,<2.0.0)
26
- Requires-Dist: numpy (<2)
27
- Requires-Dist: opencv-python (>=4.10.0.84,<5.0.0.0)
28
+ Requires-Dist: navix (>=0.7.0,<0.8.0)
29
+ Requires-Dist: numpy (>=2.2.3,<3.0.0)
30
+ Requires-Dist: omegaconf (>=2.3.0,<3.0.0)
31
+ Requires-Dist: optax (>=0.2.4,<0.3.0)
28
32
  Requires-Dist: osmnx (==2.0.0b0)
29
33
  Requires-Dist: pandas (>=2.2.2,<3.0.0)
30
34
  Requires-Dist: poetry (>=1.8.3,<2.0.0)
31
- Requires-Dist: pygame (>=2.5.2,<3.0.0)
32
35
  Requires-Dist: rasterio (>=1.3.10,<2.0.0)
33
- Requires-Dist: seaborn (>=0.13.2,<0.14.0)
34
36
  Requires-Dist: stadiamaps (>=3.2.1,<4.0.0)
35
37
  Requires-Dist: tqdm (>=4.66.4,<5.0.0)
36
- Project-URL: Repository, https://github.com/syrkis/parabellum
38
+ Requires-Dist: wandb (>=0.19.7,<0.20.0)
37
39
  Description-Content-Type: text/markdown
38
40
 
39
41
  # Parabellum
@@ -0,0 +1,8 @@
1
+ from . import aid, env, geo, types
2
+
3
+ __all__ = [
4
+ "geo",
5
+ "env",
6
+ "aid",
7
+ "types",
8
+ ]
@@ -3,10 +3,9 @@
3
3
  # by: Noah Syrkis
4
4
 
5
5
  # imports
6
- import os
7
6
  from collections import namedtuple
8
- from typing import Tuple
9
7
  import cartopy.crs as ccrs
8
+ import jax.numpy as jnp
10
9
 
11
10
  # types
12
11
  BBox = namedtuple("BBox", ["north", "south", "east", "west"]) # type: ignore
@@ -23,5 +22,20 @@ def to_mercator(bbox: BBox) -> BBox:
23
22
  def to_platecarree(bbox: BBox) -> BBox:
24
23
  proj = ccrs.PlateCarree()
25
24
  west, south = proj.transform_point(bbox.west, bbox.south, ccrs.Mercator())
25
+
26
26
  east, north = proj.transform_point(bbox.east, bbox.north, ccrs.Mercator())
27
27
  return BBox(north=north, south=south, east=east, west=west)
28
+
29
+
30
def obstacle_mask_fn(limit):
    """Precompute a line stencil for every integer offset below ``limit``.

    For each offset ``(i, j)`` with ``0 <= i, j < limit`` a ``limit x limit``
    grid is built in which the cells sampled along the straight segment from
    ``(0, 0)`` towards ``(i + 1, j + 1)`` are 1 and every other cell is 0.

    Args:
        limit: Side length of each stencil and the number of offsets per axis.

    Returns:
        An ``int8`` array of shape ``(limit, limit, limit, limit)``; entry
        ``[i, j]`` is the stencil for offset ``(i, j)``.
    """

    def aux(i, j):
        # i + j + 1 evenly spaced samples along the segment, truncated to cells.
        xs = jnp.linspace(0, i + 1, i + j + 1)
        ys = jnp.linspace(0, j + 1, i + j + 1)
        # int32 (not int8): int8 coordinates wrap to negative indices as soon
        # as limit exceeds 127 and would then mark the wrong cells.
        cc = jnp.stack((xs, ys)).astype(jnp.int32)
        # The end point can land on index == limit; with the default indexing
        # mode such out-of-range updates are skipped rather than clipped.
        return jnp.zeros((limit, limit)).at[cc[0], cc[1]].set(1)

    # Row-major walk over all offsets, using plain Python ints so the sample
    # count handed to linspace is a static integer.
    mask = jnp.stack([aux(i, j) for i in range(limit) for j in range(limit)])
    return mask.astype(jnp.int8).reshape(limit, limit, limit, limit)
@@ -0,0 +1,113 @@
1
+ # env.py
2
+ # parabellum env
3
+ # by: Noah Syrkis
4
+
5
+ # % Imports
6
+ import jax.numpy as jnp
7
+ from jax import random, Array, lax, vmap, debug
8
+ import jax.numpy.linalg as la
9
+ from typing import Tuple
10
+ from functools import partial
11
+
12
+ from parabellum.geo import geography_fn
13
+ from parabellum.types import Action, State, Obs, Scene
14
+ from parabellum import aid
15
+ import equinox as eqx
16
+
17
+
18
+ # %% Dataclass ################################################################
19
class Env:
    """Thin wrapper that binds a configuration to the functional env API.

    The configuration is expected to expose ``counts.allies`` and
    ``counts.enemies`` as mappings from unit-type name to unit count.
    """

    def __init__(self, cfg):
        self.cfg = cfg

    def reset(self, rng: Array, scene: Scene) -> Tuple[Obs, State]:
        """Create the initial state and the observation of that state."""
        return init_fn(rng, self, scene)

    def step(self, rng: Array, scene: Scene, state: State, action: Action) -> Tuple[Obs, State]:
        """Advance the simulation by one step.

        Returns:
            ``(obs, new_state)`` where ``obs`` is computed from ``new_state``,
            so the pair is consistent in the same way as the pair returned by
            ``reset``. (Previously the observation described the state
            *before* the step.)
        """
        new_state = step_fn(rng, self, scene, state, action)
        return obs_fn(self, scene, new_state), new_state

    @property
    def num_allies(self):
        """Total number of allied units across all unit types."""
        return sum(self.cfg.counts.allies.values())

    @property
    def num_enemies(self):
        """Total number of enemy units across all unit types."""
        return sum(self.cfg.counts.enemies.values())

    @property
    def num_units(self):
        """Total number of units on both teams."""
        return self.num_allies + self.num_enemies
40
+
41
+
42
+ # %% Functions ################################################################
43
@eqx.filter_jit
def init_fn(rng: Array, env: Env, scene: Scene) -> Tuple[Obs, State]:
    """Build the initial state and its observation.

    Every unit starts with the health of its unit type, zero cooldown, and a
    position drawn from a normal distribution (scale 2) centred on the middle
    of the map (``env.cfg.size / 2``).
    """
    pos_key = random.split(rng)[1]  # only the second sub-key is consumed
    centre = env.cfg.size / 2
    spread = random.normal(pos_key, (scene.unit_types.size, 2)) * 2
    start_health = jnp.ones(env.num_units) * scene.unit_type_health[scene.unit_types]
    state = State(
        unit_position=spread + centre,
        unit_health=start_health,
        unit_cooldown=jnp.zeros(env.num_units),
    )
    return obs_fn(env, scene, state), state
50
+
51
+
52
@eqx.filter_jit  # env.cfg.knn is static configuration, so the function can be jitted
def obs_fn(env, scene: Scene, state: State) -> Obs:
    """Describe, for every unit, its ``env.cfg.knn`` approximately nearest units.

    Health, cooldown and relative position of each neighbour are multiplied
    by the visibility mask from ``mask_fn``, so masked-out neighbours read as
    zeros. Neighbour indices are returned unmasked.
    """
    positions = state.unit_position
    pairwise = la.norm(positions[:, None] - positions, axis=-1)  # all-to-all distances
    dists, idxs = lax.approx_min_k(pairwise, k=env.cfg.knn)
    visible = mask_fn(scene, state, dists, idxs)
    relative = positions[:, None, ...] - positions[idxs]
    return Obs(
        unit_id=idxs,
        unit_pos=relative * visible[..., None],
        unit_health=state.unit_health[idxs] * visible,
        unit_cooldown=state.unit_cooldown[idxs] * visible,
    )
61
+
62
+
63
@eqx.filter_jit
def step_fn(rng, env: Env, scene: Scene, state: State, action: Action) -> State:
    """Apply one action per unit and return the resulting state.

    Units whose ``action.kinds`` entry is falsy move by ``action.coord``; the
    move is rejected (old position kept) when it leaves the map or ends on a
    building cell. Health is updated by ``blast_fn``; cooldown is passed
    through unchanged.
    """
    old_pos = state.unit_position
    is_move = 1 - action.kinds[..., None]  # 1 where the unit moves, 0 where it attacks
    candidate = old_pos + action.coord * is_move

    off_map = ((candidate < 0) | (candidate >= env.cfg.size)).any(axis=-1, keepdims=True)
    rows, cols = candidate.astype(jnp.int32).T
    on_building = (scene.terrain.building[rows, cols] > 0)[..., None]

    # fall back to the previous position wherever the candidate is invalid
    accepted = jnp.where(off_map | on_building, old_pos, candidate)
    health = blast_fn(rng, env, scene, state, action)
    return State(unit_position=accepted, unit_health=health, unit_cooldown=state.unit_cooldown)
71
+
72
+
73
def blast_fn(rng, env: Env, scene: Scene, state: State, action: Action):
    """Return unit health after all attack actions of this step.

    Each attacking unit (truthy ``action.kinds``) strikes the point
    ``position + action.coord``. Every unit within the attacker's type reach
    of that point loses the attacker's type damage.

    Fixes relative to the previous version, which mixed up the two axes of
    the distance matrix:
      * reach and damage are looked up for the attacker, not for the target;
      * damage is accumulated per target, not per attacker;
      * non-attacking units never hit (previously ``dist <= 0`` let an idle
        unit damage anything exactly at its action point, including itself).

    Args:
        rng: Unused; kept for interface compatibility.
        env: Unused; kept for interface compatibility.
        scene: Provides ``unit_types``, ``unit_type_reach``, ``unit_type_damage``.
        state: Provides ``unit_position`` and ``unit_health``.
        action: Provides ``coord`` (offset of the strike point) and ``kinds``.

    Returns:
        Array of updated health values, one per unit.
    """
    impact = state.unit_position + action.coord
    # dist[a, t]: distance from attacker a's strike point to unit t
    dist = la.norm(state.unit_position[None, ...] - impact[:, None, ...], axis=-1)
    reach = scene.unit_type_reach[scene.unit_types][..., None]  # per attacker
    power = scene.unit_type_damage[scene.unit_types][..., None]  # per attacker
    attacking = action.kinds[..., None].astype(bool)
    hits = (dist <= reach) & attacking
    damage = (hits * power).sum(axis=0)  # total damage received by each target
    return state.unit_health - damage
78
+
79
+
80
+ # @eqx.filter_jit
81
def scene_fn(cfg):  # init's a scene
    """Assemble the static ``Scene`` for a configuration.

    Reads ``cfg.types`` (unit-type specs with a ``name`` and the numeric
    attributes below), ``cfg.counts`` (per-team mapping of type name to unit
    count), ``cfg.place`` and ``cfg.size``.
    """
    specs_by_name = sorted(cfg.types, key=lambda spec: spec.name)

    def attr_table(attr):
        # one value per unit type, ordered by type name
        return jnp.array([spec[attr] for spec in specs_by_name])

    def type_ids(team):
        # units of a team are laid out type by type, in name order of that
        # team's count entries; the id is the position in that ordering
        counts = [count for _, count in sorted(cfg.counts[team].items())]
        return jnp.concat([jnp.zeros(count) + idx for idx, count in enumerate(counts)])

    fields = {
        f"unit_type_{attr}": attr_table(attr)
        for attr in ("health", "damage", "reload", "reach", "sight", "speed")
    }
    fields["terrain"] = geography_fn(cfg.place, cfg.size)

    n_allies = sum(cfg.counts.allies.values())
    n_enemies = sum(cfg.counts.enemies.values())
    unit_teams = jnp.concat((jnp.zeros(n_allies), jnp.ones(n_enemies))).astype(jnp.int32)
    unit_types = jnp.concat((type_ids("allies"), type_ids("enemies"))).astype(jnp.int32)

    # stencils sized by the largest sight value of any unit type
    mask = aid.obstacle_mask_fn(max(spec["sight"] for spec in cfg.types))
    return Scene(unit_teams=unit_teams, unit_types=unit_types, mask=mask, **fields)  # type: ignore
91
+
92
+
93
@eqx.filter_jit
def mask_fn(scene, state, dists, idxs):
    """Compute which of each unit's neighbours are kept in its observation.

    Args:
        scene: Provides ``unit_type_sight`` and ``unit_types`` and is passed
            on to ``obstacle_fn``.
        state: Provides ``unit_position``.
        dists: Distance from each unit to each of its selected neighbours.
        idxs: Indices of those neighbours into the unit arrays.

    Returns:
        Boolean array shaped like ``dists``.
    """
    mask = dists < scene.unit_type_sight[scene.unit_types][..., None]  # within own sight radius
    # Grid coordinates as int32 (the same cast step_fn uses). The previous
    # int8 cast wrapped around for any coordinate above 127, i.e. on every map
    # wider than 127 cells.
    grid_pos = state.unit_position[idxs].astype(jnp.int32)
    # NOTE(review): the two conditions are OR-ed, so a neighbour counts as
    # visible when it is in sight range OR the obstacle check passes. If both
    # should be required this needs to be `&` - confirm intended semantics.
    mask = mask | obstacle_fn(scene, grid_pos)
    return mask
98
+
99
+
100
@partial(vmap, in_axes=(None, 0))  # 5 x 2 # not the best name for a fn
def obstacle_fn(scene, pos):
    """Run ``slice_fn`` from the first position in ``pos`` to every position in ``pos``.

    Mapped over the leading axis of ``pos``; ``scene`` is shared.
    """
    # presumably pos[0] is the observing unit itself (its own nearest
    # neighbour) - verify against the caller
    origin = pos[0]
    return slice_fn(scene, origin, pos)
104
+
105
+
106
@partial(vmap, in_axes=(None, None, 0))
def slice_fn(scene, source, target):  # returns a 10 x 10 view with unit at top left corner, and terrain downwards
    """Check whether the stencil between ``source`` and ``target`` is free of buildings.

    Mapped over ``target``. A square window of ``scene.terrain.building`` with
    side ``scene.mask.shape[-1]`` is cut out, conditionally flipped, multiplied
    by the stencil selected by ``|source - target|`` and summed.

    Returns:
        ``True`` when no building cell coincides with the stencil.
    """
    offset = source - target
    delta = (offset >= 0) * 2 - 1  # +1 where source >= target, else -1 (per axis)
    # NOTE(review): the step is a hard-coded 10 while the window side below is
    # scene.mask.shape[-1]; these only agree when the mask was built with
    # limit == 10 - confirm.
    corner = jnp.minimum(source, source + delta * 10)
    side = scene.mask.shape[-1]
    window = lax.dynamic_slice(scene.terrain.building, corner, (side, side))
    # NOTE(review): the first flip reverses BOTH axes (no axis argument), the
    # second reverses axis 1 only - verify this is the intended orientation.
    window = lax.cond(delta[0] == 1, jnp.flip, lambda w: w, window)
    window = lax.cond(delta[1] == 1, partial(jnp.flip, axis=1), lambda w: w, window)
    dx, dy = jnp.abs(offset)
    return (scene.mask[dx, dy] * window).sum() == 0
@@ -3,25 +3,23 @@
3
3
  # by: Noah Syrkis
4
4
 
5
5
  # %% Imports
6
- from parabellum import tps
7
- import rasterio
8
6
  from rasterio import features, transform
7
+ from jax import tree
9
8
  from geopy.geocoders import Nominatim
10
9
  from geopy.distance import distance
11
10
  import contextily as cx
12
11
  import jax.numpy as jnp
13
12
  import cartopy.crs as ccrs
14
13
  from jaxtyping import Array
15
- import numpy as np
16
14
  from shapely import box
17
15
  import osmnx as ox
18
16
  import geopandas as gpd
19
17
  from collections import namedtuple
20
18
  from typing import Tuple
21
19
  import matplotlib.pyplot as plt
22
- import seaborn as sns
23
- import os
20
+ from cachier import cachier
24
21
  from jax.scipy.signal import convolve
22
+ from parabellum.types import Terrain
25
23
 
26
24
  # %% Types
27
25
  Coords = Tuple[float, float]
@@ -79,27 +77,23 @@ def basemap_fn(bbox: BBox, gdf) -> Array:
79
77
  return image
80
78
 
81
79
 
82
- def geography_fn(place, buffer=400):
80
@cachier()
def geography_fn(place, buffer) -> Terrain:
    """Fetch map features for ``place`` and rasterise them into a ``Terrain``.

    The result is cached by ``cachier``. Features are clipped to the bounding
    box from ``get_bbox(place, buffer)``, projected to EPSG:3857 and rasterised
    to ``buffer x buffer`` layers; every layer (and the basemap) is rotated by
    three quarter turns and finally cast to ``int16``.
    """

    def rotate(layer):
        return jnp.rot90(layer, 3)

    bbox = get_bbox(place, buffer)
    features_gdf = gpd.GeoDataFrame(ox.features_from_bbox(bbox=bbox, tags=tags))
    clip_area = box(bbox.west, bbox.south, bbox.east, bbox.north)
    features_gdf = features_gdf.clip(clip_area).to_crs("EPSG:3857")

    # presumably layer order is 0 building, 1 water, 2 highway, 3 forest,
    # 4 garden (per the note in the previous version) - verify against raster_fn
    raster = raster_fn(features_gdf, shape=(buffer, buffer))
    basemap = rotate(basemap_fn(bbox, features_gdf))

    kernel = jnp.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]])
    # water minus a 3x3 dilation of cells where layers 1 and 2 overlap
    water = raster[1] - convolve(raster[1] * raster[2], kernel, mode="same") > 0
    terrain = Terrain(
        building=rotate(raster[0]),
        water=rotate(water),
        forest=rotate(jnp.logical_or(raster[3], raster[4])),
        basemap=basemap,
    )
    return tree.map(lambda leaf: leaf.astype(jnp.int16), terrain)
104
98
 
105
99
 
@@ -128,25 +122,25 @@ def feature_fn(t, feature, gdf, shape):
128
122
 
129
123
 
130
124
  # %%
131
- def normalize(x):
132
- return (np.array(x) - m) / (M - m)
133
-
134
-
135
- def get_bridges(gdf):
136
- xmin, ymin, xmax, ymax = gdf.total_bounds
137
- m = np.array([xmin, ymin])
138
- M = np.array([xmax, ymax])
139
-
140
- bridges = {}
141
- for idx, bridge in gdf[gdf["bridge"] == "yes"].iterrows():
142
- if type(bridge["name"]) == str:
143
- bridges[idx[1]] = {
144
- "name": bridge["name"],
145
- "coords": normalize(
146
- [bridge.geometry.centroid.x, bridge.geometry.centroid.y]
147
- ),
148
- }
149
- return bridges
125
+ # def normalize(x):
126
+ # return (np.array(x) - m) / (M - m)
127
+
128
+
129
+ # def get_bridges(gdf):
130
+ # xmin, ymin, xmax, ymax = gdf.total_bounds
131
+ # m = np.array([xmin, ymin])
132
+ # M = np.array([xmax, ymax])
133
+
134
+ # bridges = {}
135
+ # for idx, bridge in gdf[gdf["bridge"] == "yes"].iterrows():
136
+ # if type(bridge["name"]) == str:
137
+ # bridges[idx[1]] = {
138
+ # "name": bridge["name"],
139
+ # "coords": normalize(
140
+ # [bridge.geometry.centroid.x, bridge.geometry.centroid.y]
141
+ # ),
142
+ # }
143
+ # return bridges
150
144
 
151
145
 
152
146
  """
@@ -0,0 +1,6 @@
1
+ # model.py
2
+ # jax model for mapping from observation to action
3
+ # by: Noah Syrkis
4
+
5
+ # Imports
6
+ import jax.numpy as jnp
@@ -0,0 +1 @@
1
+
@@ -0,0 +1,166 @@
1
+ # %%
2
+ import numpy as np
3
+ import jax.numpy as jnp
4
+ from parabellum.types import Terrain
5
+
6
+
7
+ # %%
8
+ def map_raster_from_line(raster, line, size):
9
+ x0, y0, dx, dy = line
10
+ x0 = int(x0 * size)
11
+ y0 = int(y0 * size)
12
+ dx = int(dx * size)
13
+ dy = int(dy * size)
14
+ max_T = int(2**0.5 * size)
15
+ for t in range(max_T + 1):
16
+ alpha = t / float(max_T)
17
+ x = x0 if dx == 0 else int((1 - alpha) * x0 + alpha * (x0 + dx))
18
+ y = y0 if dy == 0 else int((1 - alpha) * y0 + alpha * (y0 + dy))
19
+ if 0 <= x < size and 0 <= y < size:
20
+ raster[x, y] = 1
21
+ return raster
22
+
23
+
24
+ # %%
25
+ def map_raster_from_rect(raster, rect, size):
26
+ x0, y0, dx, dy = rect
27
+ x0 = int(x0 * size)
28
+ y0 = int(y0 * size)
29
+ dx = int(dx * size)
30
+ dy = int(dy * size)
31
+ raster[x0 : x0 + dx, y0 : y0 + dy] = 1
32
+ return raster
33
+
34
+
35
+ # %%
36
+ building_color = jnp.array([201, 199, 198, 255])
37
+ water_color = jnp.array([193, 237, 254, 255])
38
+ forest_color = jnp.array([197, 214, 185, 255])
39
+ empty_color = jnp.array([255, 255, 255, 255])
40
+
41
+
42
+ def make_terrain(terrain_args, size):
43
+ args = {}
44
+ for key, config in terrain_args.items():
45
+ raster = np.zeros((size, size))
46
+ if config is not None:
47
+ for elem in config:
48
+ if "line" in elem:
49
+ raster = map_raster_from_line(raster, elem["line"], size)
50
+ elif "rect" in elem:
51
+ raster = map_raster_from_rect(raster, elem["rect"], size)
52
+ args[key] = jnp.array(raster.T)
53
+ basemap = jnp.where(
54
+ args["building"][:, :, None], jnp.tile(building_color, (size, size, 1)), jnp.tile(empty_color, (size, size, 1))
55
+ )
56
+ basemap = jnp.where(args["water"][:, :, None], jnp.tile(water_color, (size, size, 1)), basemap)
57
+ basemap = jnp.where(args["forest"][:, :, None], jnp.tile(forest_color, (size, size, 1)), basemap)
58
+ args["basemap"] = basemap
59
+ return Terrain(**args)
60
+
61
+
62
+ # %%
63
+ db = {
64
+ "blank": {"building": None, "water": None, "forest": None},
65
+ "F": {
66
+ "building": [
67
+ {"line": [0.25, 0.33, 0.5, 0]},
68
+ {"line": [0.75, 0.33, 0.0, 0.25]},
69
+ {"line": [0.50, 0.33, 0.0, 0.25]},
70
+ ],
71
+ "water": None,
72
+ "forest": None,
73
+ },
74
+ "stronghold": {
75
+ "building": [
76
+ {"line": [0.2, 0.275, 0.2, 0.0]},
77
+ {"line": [0.2, 0.275, 0.0, 0.2]},
78
+ {"line": [0.4, 0.275, 0.0, 0.2]},
79
+ {"line": [0.2, 0.475, 0.2, 0.0]},
80
+ {"line": [0.2, 0.525, 0.2, 0.0]},
81
+ {"line": [0.2, 0.525, 0.0, 0.2]},
82
+ {"line": [0.4, 0.525, 0.0, 0.2]},
83
+ {"line": [0.2, 0.725, 0.525, 0.0]},
84
+ {"line": [0.75, 0.25, 0.0, 0.2]},
85
+ {"line": [0.75, 0.55, 0.0, 0.19]},
86
+ {"line": [0.6, 0.25, 0.15, 0.0]},
87
+ ],
88
+ "water": None,
89
+ "forest": None,
90
+ },
91
+ "playground": {"building": [{"line": [0.5, 0.5, 0.5, 0.0]}], "water": None, "forest": None},
92
+ "playground2": {
93
+ "building": [],
94
+ "water": [{"rect": [0.0, 0.8, 0.1, 0.1]}, {"rect": [0.2, 0.8, 0.8, 0.1]}],
95
+ "forest": [{"rect": [0.0, 0.0, 1.0, 0.2]}],
96
+ },
97
+ "triangle": {
98
+ "building": [{"line": [0.33, 0.0, 0.0, 1.0]}, {"line": [0.66, 0.0, 0.0, 1.0]}],
99
+ "water": None,
100
+ "forest": None,
101
+ },
102
+ "u_shape": {
103
+ "building": [],
104
+ "water": [{"rect": [0.15, 0.2, 0.1, 0.5]}, {"rect": [0.4, 0.2, 0.1, 0.5]}, {"rect": [0.2, 0.2, 0.25, 0.1]}],
105
+ "forest": [],
106
+ },
107
+ "bridges": {
108
+ "building": [],
109
+ "water": [
110
+ {"rect": [0.475, 0.0, 0.05, 0.1]},
111
+ {"rect": [0.475, 0.15, 0.05, 0.575]},
112
+ {"rect": [0.475, 0.775, 0.05, 1.0]},
113
+ {"rect": [0.0, 0.475, 0.225, 0.05]},
114
+ {"rect": [0.275, 0.475, 0.45, 0.05]},
115
+ {"rect": [0.775, 0.475, 0.23, 0.05]},
116
+ ],
117
+ "forest": [
118
+ {"rect": [0.1, 0.625, 0.275, 0.275]},
119
+ {"rect": [0.725, 0.0, 0.3, 0.275]},
120
+ ],
121
+ },
122
+ }
123
+
124
+ # %% [raw]
125
+ # import matplotlib.pyplot as plt
126
+ # size = 100
127
+ # raster = np.zeros((size, size))
128
+ # rect = [0.475, 0., 0.05, 0.1]
129
+ # raster = map_raster_from_rect(raster, rect, size)
130
+ # rect = [0.475, 0.15, 0.05, 0.575]
131
+ # raster = map_raster_from_rect(raster, rect, size)
132
+ # rect = [0.475, 0.775, 0.05, 1.]
133
+ # raster = map_raster_from_rect(raster, rect, size)
134
+ #
135
+ # rect = [0., 0.475, 0.225, 0.05]
136
+ # raster = map_raster_from_rect(raster, rect, size)
137
+ # rect = [0.275, 0.475, 0.45, 0.05]
138
+ # raster = map_raster_from_rect(raster, rect, size)
139
+ # rect = [0.775, 0.475, 0.23, 0.05]
140
+ # raster = map_raster_from_rect(raster, rect, size)
141
+ #
142
+ # rect = [0.1, 0.625, 0.275, 0.275]
143
+ # raster = map_raster_from_rect(raster, rect, size)
144
+ # rect = [0.725, 0., 0.3, 0.275]
145
+ # raster = map_raster_from_rect(raster, rect, size)
146
+ #
147
+ # plt.imshow(raster[::-1, :])
148
+
149
+ # %% [markdown]
150
+ # # Main
151
+
152
+ # %%
153
+ if __name__ == "__main__":
154
+ import matplotlib.pyplot as plt
155
+
156
+ # %%
157
+ terrain = make_terrain(db["bridges"], size=100)
158
+
159
+ # %%
160
+ plt.imshow(jnp.rot90(terrain.basemap))
161
+ bl = (39.5, 5)
162
+ tr = (44.5, 10)
163
+ plt.scatter(bl[0], 49 - bl[1])
164
+ plt.scatter(tr[0], 49 - tr[1], marker="+")
165
+
166
+ # %%
@@ -0,0 +1,54 @@
1
+ # types.py
2
+ # parabellum types
3
+ # by: Noah Syrkis
4
+
5
+ # imports
6
+ from chex import dataclass
7
+ from jaxtyping import Array, Bool, Float16
8
+
9
+
10
+ # dataclasses
11
@dataclass
class State:
    """Mutable part of the simulation: one entry per unit in every field."""

    unit_position: Array  # 2-D coordinates of each unit
    unit_health: Array  # remaining health of each unit
    unit_cooldown: Array  # per-unit cooldown value
16
+
17
+
18
@dataclass
class Obs:
    """What each unit observes about its selected neighbouring units."""

    unit_id: Array  # indices of the neighbouring units
    unit_pos: Array  # positions of the neighbours relative to the observer
    unit_health: Array  # health of the neighbours
    unit_cooldown: Array  # cooldown of the neighbours
24
+
25
+
26
@dataclass
class Action:
    """One action per unit: a 2-D offset plus a flag selecting the action kind."""

    # NOTE(review): annotated as Float16 - confirm this matches the dtype
    # actually passed by callers
    coord: Float16[Array, "... 2"]  # noqa
    kinds: Bool[Array, "..."]  # action-kind flag per unit
30
+
31
+
32
@dataclass
class Terrain:
    """Raster layers describing the map."""

    building: Array  # building layer
    water: Array  # water layer
    forest: Array  # forest layer
    basemap: Array  # image of the map used as background
38
+
39
+
40
@dataclass
class Scene:
    """Static description of a scenario: terrain, unit roster and type stats."""

    terrain: Terrain  # map layers
    mask: Array  # precomputed stencil table

    # per-unit arrays
    unit_types: Array  # unit-type index of each unit
    unit_teams: Array  # team id of each unit

    # per-unit-type arrays, indexed by the values in unit_types
    unit_type_health: Array
    unit_type_damage: Array
    unit_type_reload: Array

    unit_type_reach: Array
    unit_type_sight: Array
    unit_type_speed: Array
@@ -0,0 +1,53 @@
1
+ [tool.poetry]
2
+ name = "parabellum"
3
+ version = "0.5.13"
4
+ description = "Parabellum environment for parallel warfare simulation"
5
+ authors = ["Noah Syrkis <desk@syrkis.com>"]
6
+ readme = "README.md"
7
+
8
+ [tool.poetry.dependencies]
9
+ python = "^3.11,<3.12"
10
+ jupyterlab = "^4.2.2"
11
+ poetry = "^1.8.3"
12
+ tqdm = "^4.66.4"
13
+ geopy = "^2.4.1"
14
+ osmnx = "2.0.0b0"
15
+ rasterio = "^1.3.10"
16
+ ipykernel = "^6.29.5"
17
+ folium = "^0.17.0"
18
+ pandas = "^2.2.2"
19
+ contextily = "^1.6.0"
20
+ einops = "^0.8.0"
21
+ jaxtyping = "^0.2.33"
22
+ cartopy = "^0.23.0"
23
+ stadiamaps = "^3.2.1"
24
+ cachier = "^3.1.2"
25
+ equinox = "^0.11.11"
26
+ jax = "^0.5.0"
27
+ gymnax = "^0.0.8"
28
+ evosax = "^0.1.6"
29
+ distrax = "^0.1.5"
30
+ optax = "^0.2.4"
31
+ flax = "^0.10.4"
32
+ numpy = "^2.2.3"
33
+ brax = "^0.12.1"
34
+ wandb = "^0.19.7"
35
+ flashbax = "^0.1.2"
36
+ navix = "^0.7.0"
37
+ omegaconf = "^2.3.0"
38
+ jax-tqdm = "^0.3.1"
39
+
40
+ [tool.poetry.group.dev.dependencies]
41
+ esch = { path = "../../esch" }
42
+ purejaxrl = { path = "../purejaxrl" }
43
+
44
+ [build-system]
45
+ requires = ["poetry-core"]
46
+ build-backend = "poetry.core.masonry.api"
47
+
48
+ [tool.pyright]
49
+ venvPath = "."
50
+ venv = ".venv"
51
+
52
+ [tool.ruff]
53
+ line-length = 120
@@ -1,22 +0,0 @@
1
- from .env import Environment, Scenario, make_scenario, State
2
- from .vis import Visualizer, Skin
3
- from .gun import bullet_fn
4
- from . import vis
5
- from . import terrain_db
6
- from . import env
7
- from . import tps
8
- # from .run import run
9
-
10
- __all__ = [
11
- "env",
12
- "terrain_db",
13
- "vis",
14
- "tps",
15
- "Environment",
16
- "Scenario",
17
- "make_scenario",
18
- "State",
19
- "Visualizer",
20
- "Skin",
21
- "bullet_fn",
22
- ]