parabellum 0.0.0__tar.gz → 0.0.73__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,16 @@
1
+ **/__pycache__
2
+ **/*.gif
3
+ **/*.svg
4
+ *.mp4
5
+ dist/
6
+ **/.DS_Store
7
+ **/.ipynb_checkpoints
8
+ cache/
9
+ pyrightconfig.json
10
+ .venv
11
+ output
12
+ parabellum/cache/
13
+ .ruff_cache/
14
+ .coverage
15
+ build/
16
+ profiler_logs/
@@ -0,0 +1,46 @@
1
+ Metadata-Version: 2.4
2
+ Name: parabellum
3
+ Version: 0.0.73
4
+ Summary: Parabellum environment for parallel warfare simulation
5
+ Author-email: Noah Syrkis <desk@syrkis.com>
6
+ Requires-Python: <3.12,>=3.11
7
+ Requires-Dist: brax<0.13,>=0.12.1
8
+ Requires-Dist: cachier<4,>=3.1.2
9
+ Requires-Dist: cartopy<0.24,>=0.23.0
10
+ Requires-Dist: contextily<2,>=1.6.0
11
+ Requires-Dist: distrax<0.2,>=0.1.5
12
+ Requires-Dist: einops<0.9,>=0.8.0
13
+ Requires-Dist: equinox>=0.12.2
14
+ Requires-Dist: evosax<0.2,>=0.1.6
15
+ Requires-Dist: flashbax<0.2,>=0.1.2
16
+ Requires-Dist: flax<0.11,>=0.10.4
17
+ Requires-Dist: folium<0.18,>=0.17.0
18
+ Requires-Dist: geopy<3,>=2.4.1
19
+ Requires-Dist: gymnax<0.0.9,>=0.0.8
20
+ Requires-Dist: ipykernel<7,>=6.29.5
21
+ Requires-Dist: ipython>=8.36.0
22
+ Requires-Dist: jax-tqdm<0.4,>=0.3.1
23
+ Requires-Dist: jax<0.7,>=0.6.0
24
+ Requires-Dist: jaxkd>=0.1.0
25
+ Requires-Dist: jaxtyping<0.3,>=0.2.33
26
+ Requires-Dist: jupyterlab<5,>=4.2.2
27
+ Requires-Dist: navix<0.8,>=0.7.0
28
+ Requires-Dist: notebook>=7.4.2
29
+ Requires-Dist: numpy>=2
30
+ Requires-Dist: omegaconf<3,>=2.3.0
31
+ Requires-Dist: optax<0.3,>=0.2.4
32
+ Requires-Dist: osmnx==2.0.0b0
33
+ Requires-Dist: pandas<3,>=2.2.2
34
+ Requires-Dist: poetry<2,>=1.8.3
35
+ Requires-Dist: rasterio<2,>=1.3.10
36
+ Requires-Dist: stadiamaps<4,>=3.2.1
37
+ Requires-Dist: tensorboard>=2.19.0
38
+ Requires-Dist: tensorflow>=2.19.0
39
+ Requires-Dist: tqdm<5,>=4.66.4
40
+ Requires-Dist: wandb<0.20,>=0.19.7
41
+ Requires-Dist: xprof>=2.20.0
42
+ Description-Content-Type: text/markdown
43
+
44
+ # Parabellum
45
+
46
+ TODO: switch to red and blue team semantics (not enemy and ally)
@@ -0,0 +1,3 @@
1
+ # Parabellum
2
+
3
+ TODO: switch to red and blue team semantics (not enemy and ally)
@@ -0,0 +1,57 @@
1
+ # %% main.py
2
+ # parabellum main
3
+ # by: Noah Syrkis
4
+
5
+ # Imports ###################################################################
6
+ from functools import partial
7
+ from typing import Tuple
8
+ import esch
9
+
10
+ from einops import repeat, rearrange
11
+ import numpy as np
12
+ from jax import jit, lax, random, tree, vmap
13
+ from jax import numpy as jnp
14
+ from jaxtyping import Array
15
+
16
+ import parabellum as pb
17
+ from parabellum.types import Action, Config, State
18
+
19
+
20
+ # %% Functions
21
def action_fn(cfg: Config, rng: Array) -> Action:
    """Sample a uniformly random action for every unit.

    Target offsets are drawn in [-1, 1)^2 and scaled by each unit's reach; the action code is
    drawn from {0, 1, 2}.
    """
    # Independent keys: drawing both samples from the same key correlates them.
    pos_key, kind_key = random.split(rng)
    reach = cfg.rules.reach[cfg.types][..., None]  # per-unit reach, broadcast over (x, y)
    pos = random.uniform(pos_key, (cfg.length, 2), minval=-1, maxval=1) * reach
    kind = random.randint(kind_key, (cfg.length,), minval=0, maxval=3)
    return Action(pos=pos, kind=kind)
25
+
26
+
27
def step_fn(state: State, rng: Array) -> Tuple[State, Tuple[State, Action]]:
    """One ``lax.scan`` step: sample a random action and advance the environment.

    Returns the new state as the carry and ``(state, action)`` as the per-step output.
    """
    # Separate keys for action sampling and the environment's own randomness.
    action_key, step_key = random.split(rng)
    action = action_fn(cfg, action_key)
    _, state = env.step(cfg, step_key, state, action)
    return state, (state, action)
31
+
32
+
33
def traj_fn(state, rng) -> Tuple[State, Tuple[State, Action]]:
    """Roll out ``cfg.steps`` random steps from ``state`` with ``lax.scan``."""
    step_keys = random.split(rng, cfg.steps)
    final_state, history = lax.scan(step_fn, state, step_keys)
    return final_state, history
36
+
37
+
38
# %% Main
env, cfg = pb.env.Env(), Config()
# one key per simulation for each phase: row 0 -> initialisation, row 1 -> rollout
init_key, traj_key = random.split(random.PRNGKey(0), (2, cfg.sims))

init = vmap(jit(partial(env.init, cfg)))
traj = vmap(jit(traj_fn))

obs, state = init(init_key)
# Roll out with traj_key; reusing init_key would tie rollout randomness to initial placement.
state, (seq, action) = traj(state, traj_key)

# NOTE(review): machine-specific output path.
pb.utils.svg_fn(cfg, seq, action, "/Users/nobr/desk/s3/parabellum/sims.svg", debug=True)
49
+ # %% Anim
50
+ # for i in range(seq.pos.shape[0]): # sims
51
+ # for j in range(seq.pos.shape[2]): # units
52
+ # shots = [(kdx, coord) for kdx, coord in enumerate(action.pos[i, :, j]) if action.shoot[i, kdx, j]]
53
+ # print(shots)
54
+ # print(tree.map(jnp.shape, action))
55
+ # print(tree.map(jnp.shape, seq))
56
+ # pb.utils.svg_fn(cfg, seq.pos, action)
57
+ # pb.utils.gif_fn(cfg, seq)
@@ -0,0 +1,4 @@
1
+ from . import env, geo, types, utils
2
+ from .env import Env
3
+
4
+ __all__ = ["geo", "env", "types", "utils", "Env"]
@@ -0,0 +1,70 @@
1
+ # %% env.py
2
+ # parabellum env
3
+ # by: Noah Syrkis
4
+
5
+ # Imports
6
+ from functools import partial
7
+ from typing import Tuple
8
+
9
+ import jax.numpy as jnp
10
+ import jaxkd as jk
11
+ from jax import random
12
+ from jaxtyping import Array
13
+
14
+ from parabellum.types import Action, Config, Obs, State
15
+
16
+
17
+ # %% Dataclass
18
class Env:
    """Stateless environment facade: state and config are passed in and returned explicitly."""

    def init(self, cfg: Config, rng: Array) -> Tuple[Obs, State]:
        """Sample an initial state for ``cfg`` and return ``(observation, state)``."""
        fresh = init_fn(cfg, rng)  # without jit this takes forever
        observation = obs_fn(cfg, fresh)
        return observation, fresh

    def step(self, cfg: Config, rng: Array, state: State, action: Action) -> Tuple[Obs, State]:
        """Advance ``state`` one tick under ``action`` and return ``(observation, next_state)``."""
        following = step_fn(cfg, rng, state, action)
        observation = obs_fn(cfg, following)
        return observation, following
26
+
27
+
28
+ # %% Functions
29
def init_fn(cfg: Config, rng: Array) -> State:
    """Place every unit on a random cell not marked in ``cfg.map`` and give it full hit points.

    Cells are drawn with replacement, so several units may start on the same cell.
    """
    weights = jnp.ones((cfg.size, cfg.size)).at[cfg.map].set(0).flatten()  # 0 on terrain cells
    cells = random.choice(rng, jnp.arange(weights.size), shape=(cfg.length,), p=weights, replace=True)
    width = len(cfg.map)
    rows, cols = cells // width, cells % width
    pos = jnp.stack((rows, cols), axis=1).astype(jnp.float32)
    return State(pos=pos, hp=cfg.hp[cfg.types])
35
+
36
+
37
def obs_fn(cfg: Config, state: State) -> Obs:  # return info about neighbors ---
    """Build every unit's observation of its ``cfg.knn`` nearest units.

    The code treats neighbor slot 0 as the observing unit itself (its nearest point, at distance 0).
    NOTE(review): with ties at identical positions this may not hold — confirm.
    Slot 0 of ``pos`` holds the unit's absolute position; other slots hold offsets relative to it.
    Neighbors at distance >= the observer's sight are zeroed out via ``mask``.
    """
    idxs, dist = jk.extras.query_neighbors_pairwise(state.pos, state.pos, k=cfg.knn)
    mask = dist < cfg.sight[cfg.types[idxs][:, 0]][..., None]  # | (state.hp[idxs] > 0)
    pos = (state.pos[idxs] - state.pos[:, None, ...]).at[:, 0, :].set(state.pos) * mask[..., None]
    # reach/sight/speed are per-*kind* tables (one entry per Kind). Map them to per-unit values
    # before gathering by neighbor index; indexing them by unit index directly is wrong.
    per_unit = (
        state.hp,
        cfg.types,
        cfg.teams,
        cfg.reach[cfg.types],
        cfg.sight[cfg.types],
        cfg.speed[cfg.types],
    )
    hp, kind, team, reach, sight, speed = (x[idxs] * mask for x in per_unit)
    return Obs(pos=pos, dist=dist, hp=hp, type=kind, team=team, reach=reach, sight=sight, speed=speed, mask=mask)
44
+
45
+
46
def step_fn(cfg: Config, rng: Array, state: State, action: Action) -> State:
    """Advance the simulation one tick.

    Positions come from move_fn followed by push_fn. Hit points come from blast_fn, evaluated
    against the pre-move state. Both use a k=2 neighbor query between the units' target points
    (``state.pos + action.pos``) and their current positions (see jaxkd for which side is the
    query — TODO confirm).
    """
    targets = state.pos + action.pos
    near_idx, near_dist = jk.extras.query_neighbors_pairwise(targets, state.pos, k=2)
    shared = (rng, cfg, state, action, near_idx, near_dist)
    moved = move_fn(*shared)
    new_pos = push_fn(cfg, rng, near_idx, near_dist, moved)
    new_hp = blast_fn(*shared)
    return State(pos=new_pos, hp=new_hp)  # type: ignore
50
+
51
+
52
def move_fn(rng: Array, cfg: Config, state: State, action: Action, idx: Array, norm: Array) -> Array:
    """Move units whose action is ``move`` by their target offset, clipped to their speed.

    A proposed position is rejected (the unit stays put) if it leaves the ``[0, cfg.size)`` grid
    or lands on a cell where ``cfg.map`` is set. ``rng``, ``idx`` and ``norm`` are unused.
    """
    step = cfg.speed[cfg.types][..., None]  # per-unit cap on each coordinate of the step
    proposal = state.pos + jnp.clip(action.pos, -step, step) * action.move[..., None]
    outside = jnp.any(proposal < 0, axis=-1) | jnp.any(proposal >= cfg.size, axis=-1)
    rows, cols = jnp.int32(proposal).T
    on_terrain = cfg.map[rows, cols] > 0
    rejected = (outside | on_terrain)[..., None]
    return jnp.where(rejected, state.pos, proposal)
57
+
58
+
59
def blast_fn(rng: Array, cfg: Config, state: State, action: Action, idx: Array, norm: Array) -> Array:
    """Subtract each casting unit's damage from every unit listed in its row of ``idx``.

    Negative damage (e.g. medic) heals. A unit listed several times is hit several times.
    NOTE(review): damage is not gated by distance (``norm``) or by ``cfg.blast``.
    """
    per_caster = cfg.dam[cfg.types] * action.cast  # zero for units that did not cast
    per_hit = jnp.broadcast_to(per_caster[..., None], idx.shape)  # one entry per (caster, target)
    received = jnp.zeros(cfg.length, dtype=jnp.int32).at[idx.ravel()].add(per_hit.ravel())
    return state.hp - received
62
+
63
+
64
def push_fn(cfg: Config, rng: Array, idx: Array, norm: Array, pos: Array) -> Array:
    """Jitter unit positions with Gaussian noise of standard deviation 0.1.

    ``cfg``, ``idx`` and ``norm`` are kept for interface compatibility with step_fn but are
    unused. The neighbor-repulsion step that read them sat after an unconditional ``return``,
    never ran, and has been removed.
    NOTE(review): if repulsion is revived, its params "need to be tweaked, and matched with unit
    size". Also, step_fn's ``idx``/``norm`` come from a query around action *targets*, not
    around ``pos``.
    """
    return pos + random.normal(rng, pos.shape) * 0.1
@@ -0,0 +1,213 @@
1
+ # %% geo.py
2
+ # script for geospatial level generation
3
+ # by: Noah Syrkis
4
+
5
# %% Imports
import os
from collections import namedtuple
from typing import Tuple

import cartopy.crs as ccrs
import contextily as cx
import geopandas as gpd
import jax.numpy as jnp
import matplotlib.pyplot as plt
import osmnx as ox
from cachier import cachier
from geopy.distance import distance

# from jax import tree
from geopy.geocoders import Nominatim
from jaxtyping import Array
from rasterio import features, transform
from shapely import box

# from jax.scipy.signal import convolve
# from parabellum.types import Terrain
+
25
# %% Types
# (latitude, longitude) pair, as returned by get_coordinates
Coords = Tuple[float, float]
# Bounding box in EPSG:4326 degrees: north/south are latitudes, east/west are longitudes.
BBox = namedtuple("BBox", ["north", "south", "east", "west"])  # type: ignore
28
+
29
# %% Constants
# Stadia Maps tile provider used by basemap_fn. The API key is read from STADIA_API_KEY instead
# of being committed to source. NOTE(review): the key previously hard-coded here shipped in
# public releases and should be rotated.
provider = cx.providers.Stadia.StamenTerrain(  # type: ignore
    api_key=os.environ.get("STADIA_API_KEY", "")
)
provider["url"] = provider["url"] + "?api_key={api_key}"
# OSM tags fetched by geography_fn. raster_fn produces one channel per top-level key, in this
# order (building, water, highway, landuse, leisure).
tags = {
    "building": True,
    "water": True,
    "highway": True,
    "landuse": [
        "grass",
        "forest",
        "flowerbed",
        "greenfield",
        "village_green",
        "recreation_ground",
    ],
    "leisure": "garden",
}  # "road": True}
48
+
49
+
50
+ # %% Coordinate function
51
def get_coordinates(place: str) -> Coords:
    """Geocode ``place`` with Nominatim and return its ``(latitude, longitude)``.

    Raises:
        ValueError: if Nominatim returns no match for ``place``.
    """
    geolocator = Nominatim(user_agent="parabellum")
    point = geolocator.geocode(place)
    if point is None:  # geopy signals "no match" with None rather than an exception
        raise ValueError(f"Could not geocode place: {place!r}")
    return point.latitude, point.longitude  # type: ignore
55
+
56
+
57
def get_bbox(place: str, buffer) -> BBox:
    """Return the EPSG:4326 box reaching ``buffer`` meters from ``place`` in each compass direction."""
    center = get_coordinates(place)
    offset = distance(meters=buffer)
    # destination point for each bearing (0 = north, 90 = east, 180 = south, 270 = west)
    edge = {bearing: offset.destination(center, bearing=bearing) for bearing in (0, 90, 180, 270)}
    return BBox(
        north=edge[0].latitude,
        south=edge[180].latitude,
        east=edge[90].longitude,
        west=edge[270].longitude,
    )  # type: ignore
65
+
66
+
67
def basemap_fn(bbox: BBox, gdf) -> Array:
    """Render Stadia terrain tiles under ``gdf``'s extent and return the RGBA pixels as an array.

    ``bbox`` is accepted for interface compatibility but unused; the extent comes from
    ``gdf.total_bounds``. The figure is always closed, even if tile fetching or drawing raises.
    """
    fig, ax = plt.subplots(figsize=(20, 20), subplot_kw={"projection": ccrs.Mercator()})
    try:
        gdf.plot(ax=ax, color="black", alpha=0, edgecolor="black")  # type: ignore
        cx.add_basemap(ax, crs=gdf.crs, source=provider, zoom="auto")  # type: ignore
        bounds = gdf.total_bounds  # (minx, miny, maxx, maxy)
        ax.set_extent([bounds[0], bounds[2], bounds[1], bounds[3]], crs=ccrs.Mercator())  # type: ignore
        plt.axis("off")
        plt.tight_layout(pad=0)
        fig.canvas.draw()
        image = jnp.array(fig.canvas.renderer._renderer)  # type: ignore
    finally:
        plt.close(fig)  # release the figure on both success and failure
    return image
79
+
80
+
81
@cachier()
def geography_fn(place, buffer):
    """Rasterize OSM features tagged ``building`` around ``place`` into a boolean grid.

    ``buffer`` serves twice: as the bounding-box half-width in meters (get_bbox) and as the
    ``(buffer, buffer)`` raster resolution. Results are memoized by cachier, keyed on the arguments.
    NOTE(review): osmnx 2.x documents ``bbox`` as (left, bottom, right, top), whereas BBox is
    (north, south, east, west) — confirm against the pinned osmnx==2.0.0b0.
    """
    bbox = get_bbox(place, buffer)
    frame = gpd.GeoDataFrame(ox.features_from_bbox(bbox=bbox, tags=tags))
    clip_box = box(bbox.west, bbox.south, bbox.east, bbox.north)
    projected = frame.clip(clip_box).to_crs("EPSG:3857")
    channels = raster_fn(projected, shape=(buffer, buffer))
    # Channel 0 follows the first key of ``tags`` ("building"); the other channels are unused for now.
    return channels[0].astype(jnp.bool_)
99
+
100
+
101
+ # =======
102
+ # terrain = tps.Terrain(building=trans(raster[0] - convolve(raster[0]*raster[2], kernel, mode='same')>0),
103
+ # water=trans(raster[1] - convolve(raster[1]*raster[2], kernel, mode='same')>0),
104
+ # forest=trans(jnp.logical_or(raster[3], raster[4])),
105
+ # basemap=basemap)
106
+ # return terrain, gdf
107
+ # >>>>>>> aeb13033e57083cc512a60f8f60a3db47a65ac32
108
+
109
+
110
def raster_fn(gdf, shape) -> Array:
    """Stack one rasterized channel per key of ``tags``, covering ``gdf``'s total bounds.

    NOTE(review): the extent is the features' bounds, not the requested bbox. If features do not
    reach the box edges, the grid is stretched to the features — confirm this is intended.
    """
    west, south, east, north = gdf.total_bounds
    affine = transform.from_bounds(west, south, east, north, *shape)  # type: ignore
    channels = [feature_fn(affine, key, gdf, shape) for key in tags]
    return jnp.array(channels)
115
+
116
+
117
def feature_fn(t, feature, gdf, shape):
    """Rasterize the rows of ``gdf`` that carry the OSM tag ``feature`` onto a grid of ``shape``.

    ``t`` is the affine transform from raster cells to ``gdf``'s CRS. Cells touched by a tagged
    geometry get rasterio's default burn value (1); all others are 0. An all-zero grid is
    returned when no row carries the tag — either the column is missing or every value in it is
    NaN (which can happen after clipping).
    """
    if feature not in gdf.columns:
        return jnp.zeros(shape)
    gdf = gdf[~gdf[feature].isna()]
    if gdf.empty:
        # Nothing to burn; do not hand rasterio an empty geometry list.
        return jnp.zeros(shape)
    raster = features.rasterize(gdf.geometry, out_shape=shape, transform=t, fill=0)  # type: ignore
    return raster
123
+
124
+
125
+ # %%
126
+ # def normalize(x):
127
+ # return (np.array(x) - m) / (M - m)
128
+
129
+
130
+ # def get_bridges(gdf):
131
+ # xmin, ymin, xmax, ymax = gdf.total_bounds
132
+ # m = np.array([xmin, ymin])
133
+ # M = np.array([xmax, ymax])
134
+
135
+ # bridges = {}
136
+ # for idx, bridge in gdf[gdf["bridge"] == "yes"].iterrows():
137
+ # if type(bridge["name"]) == str:
138
+ # bridges[idx[1]] = {
139
+ # "name": bridge["name"],
140
+ # "coords": normalize(
141
+ # [bridge.geometry.centroid.x, bridge.geometry.centroid.y]
142
+ # ),
143
+ # }
144
+ # return bridges
145
+
146
+
147
+ """
148
+ # %%
149
+ if __name__ == "__main__":
150
+ place = "Thun, Switzerland"
151
+ <<<<<<< HEAD
152
+ terrain = geography_fn(place, 300)
153
+
154
+ =======
155
+ terrain, gdf = geography_fn(place, 300)
156
+
157
+ >>>>>>> aeb13033e57083cc512a60f8f60a3db47a65ac32
158
+ fig, axes = plt.subplots(1, 5, figsize=(20, 20))
159
+ axes[0].imshow(jnp.rot90(terrain.building), cmap="gray")
160
+ axes[1].imshow(jnp.rot90(terrain.water), cmap="gray")
161
+ axes[2].imshow(jnp.rot90(terrain.forest), cmap="gray")
162
+ axes[3].imshow(jnp.rot90(terrain.building + terrain.water + terrain.forest))
163
+ axes[4].imshow(jnp.rot90(terrain.basemap))
164
+
165
+ # %%
166
+ W, H, _ = terrain.basemap.shape
167
+ bridges = get_bridges(gdf)
168
+
169
+ # %%
170
+ print("Bridges:")
171
+ for bridge in bridges.values():
172
+ x, y = int(bridge["coords"][0]*300), int(bridge["coords"][1]*300)
173
+ print(bridge["name"], f"at ({x}, {y})")
174
+
175
+ # %%
176
+ plt.subplots(figsize=(7,7))
177
+ plt.imshow(jnp.rot90(terrain.basemap))
178
+ X = [b["coords"][0]*W for b in bridges.values()]
179
+ Y = [(1-b["coords"][1])*H for b in bridges.values()]
180
+ plt.scatter(X, Y)
181
+ for i in range(len(X)):
182
+ x,y = int(X[i]), int(Y[i])
183
+ plt.text(x, y, str((int(x/W*300), int((1-(y/H))*300))))
184
+
185
+ # %%
186
+
187
+ # %% [raw]
188
+ # fig, ax = plt.subplots(figsize=(10, 10))
189
+ # gdf.plot(ax=ax, color='lightgray') # Plot all features
190
+ # bridges.plot(ax=ax, color='red') # Highlight bridges in red
191
+ # plt.show()
192
+
193
+ # %%
194
+
195
+ """
196
+
197
+ # BBox = namedtuple("BBox", ["north", "south", "east", "west"]) # type: ignore
198
+
199
+
200
+ # def to_mercator(bbox: BBox) -> BBox:
201
+ # proj = ccrs.Mercator()
202
+ # west, south = proj.transform_point(bbox.west, bbox.south, ccrs.PlateCarree())
203
+ # east, north = proj.transform_point(bbox.east, bbox.north, ccrs.PlateCarree())
204
+ # return BBox(north=north, south=south, east=east, west=west)
205
+ #
206
+ #
207
+ # def to_platecarree(bbox: BBox) -> BBox:
208
+ # proj = ccrs.PlateCarree()
209
+ # west, south = proj.transform_point(bbox.west, bbox.south, ccrs.Mercator())
210
+ #
211
+ # east, north = proj.transform_point(bbox.east, bbox.north, ccrs.Mercator())
212
+ # return BBox(north=north, south=south, east=east, west=west)
213
+ #
@@ -0,0 +1,138 @@
1
+ # types.py
2
+ # parabellum types
3
+ # by: Noah Syrkis
4
+
5
+ # imports
6
+ from chex import dataclass
7
+ from jaxtyping import Array, Float32, Int
8
+ import jax.numpy as jnp
9
+ from parabellum.geo import geography_fn
10
+ from dataclasses import field
11
+
12
+
13
@dataclass
class Kind:
    """Stat block for one unit kind; Rules gathers five of these into per-stat arrays."""

    hp: int  # starting hit points (init_fn sets State.hp from cfg.hp[cfg.types])
    dam: int  # damage per cast in env.blast_fn; negative values heal (see medic)
    speed: int  # per-coordinate cap on a move (env.move_fn clips action offsets to it)
    reach: int  # main.action_fn scales random target offsets by this
    sight: int  # env.obs_fn masks neighbors at distance >= this
    blast: int  # drawn by utils.svg_fn; NOTE(review): env.blast_fn does not read it
    r: float  # drawing size of the unit in utils.svg_fn
22
+
23
+
24
@dataclass
class Rules:
    """Unit stats per kind; ``__post_init__`` stacks each stat into an array indexed by kind id.

    Kind ids follow declaration order: 0 troop, 1 armor, 2 plane, 3 civil, 4 medic. The Kind
    instances are unannotated class attributes, so they are not dataclass fields.
    """

    troop = Kind(hp=120, dam=15, speed=2, reach=4, sight=4, blast=1, r=1)
    armor = Kind(hp=150, dam=12, speed=1, reach=8, sight=16, blast=3, r=2)
    plane = Kind(hp=80, dam=20, speed=4, reach=16, sight=32, blast=2, r=2)
    civil = Kind(hp=100, dam=0, speed=3, reach=5, sight=10, blast=1, r=2)
    medic = Kind(hp=100, dam=-10, speed=3, reach=5, sight=10, blast=1, r=2)

    def __post_init__(self):
        kinds = (self.troop, self.armor, self.plane, self.civil, self.medic)
        for stat in ("hp", "dam", "r", "speed", "reach", "sight", "blast"):
            setattr(self, stat, jnp.array(tuple(getattr(kind, stat) for kind in kinds)))
48
+
49
+
50
@dataclass
class Team:
    """Unit count per kind for one team; derives the team size and each unit's kind id."""

    troop: int = 1
    armor: int = 0
    plane: int = 0
    civil: int = 0
    medic: int = 0

    def __post_init__(self):
        counts = (self.troop, self.armor, self.plane, self.civil, self.medic)
        self.length: int = sum(counts)
        # kind id per unit (0 troop ... 4 medic), units grouped by kind in declaration order
        self.types: Array = jnp.repeat(jnp.arange(len(counts)), jnp.array(counts))
63
+
64
+
65
+ # dataclasses
66
@dataclass
class State:
    """Per-unit simulation state; units are ordered blue team first, then red (see Config)."""

    pos: Array  # unit positions, shape (cfg.length, 2) as built by env.init_fn
    hp: Array  # hit points per unit, shape (cfg.length,)
    # target: Array
71
+
72
+
73
@dataclass
class Obs:
    """Per-unit neighborhood observation, as produced by env.obs_fn.

    Fields are indexed ``(..., neighbor)``. env.obs_fn places the observing unit in neighbor
    slot 0, and ``mask`` marks neighbors within the observer's sight.
    """

    # idxs: Array
    hp: Array
    pos: Array
    type: Array
    team: Array
    dist: Array
    mask: Array
    reach: Array
    sight: Array
    speed: Array

    @property
    def ally(self):
        """Visible neighbors on the observer's team."""
        # Compare each row against its own slot 0. ``self.team[0]`` would pick the first *row*
        # whenever the observation is batched over units.
        return (self.team == self.team[..., :1]) & self.mask

    @property
    def enemy(self):
        """Visible neighbors on the other team."""
        return (self.team != self.team[..., :1]) & self.mask
93
+
94
+
95
@dataclass
class Action:
    """Per-unit action: a target offset ``pos`` plus an integer action code ``kind``.

    Codes: 0 = invalid, 1 = move, 2 = cast (bomb, bullet or medicine).
    """

    pos: Array
    kind: Int[Array, "..."]  # 0 = invalid, 1 = move, 2 = cast

    def _is(self, code):
        # elementwise test of the action code
        return self.kind == code

    @property
    def invalid(self):
        return self._is(0)

    @property
    def move(self):
        return self._is(1)

    @property
    def cast(self):  # cast bomb, bullet or medicine
        return self._is(2)
111
+
112
+
113
@dataclass
class Config:  # Remove frozen=True for now
    """Simulation configuration; ``__post_init__`` derives per-unit and per-kind arrays once."""

    steps: int = 123
    place: str = "Palazzo della Civiltà Italiana, Rome, Italy"
    force: float = 0.5
    sims: int = 2
    size: int = 64
    knn: int = 2
    blu: Team = field(default_factory=Team)
    red: Team = field(default_factory=Team)
    rules: Rules = field(default_factory=Rules)

    def __post_init__(self):
        # Unit tables: blue units first, then red.
        self.length: int = self.blu.length + self.red.length
        self.types: Array = jnp.concat((self.blu.types, self.red.types))
        self.teams: Array = jnp.repeat(jnp.arange(2), jnp.array((self.blu.length, self.red.length)))
        # Per-kind stat tables, mirrored from the rules for shorter access paths.
        for stat in ("hp", "dam", "r", "speed", "reach", "sight", "blast"):
            setattr(self, stat, getattr(self.rules, stat))
        # Terrain grid for the configured place (geography_fn is memoized by cachier).
        self.map: Array = geography_fn(self.place, self.size)
        self.root: Array = jnp.int32(jnp.sqrt(self.length))
@@ -0,0 +1,59 @@
1
+ # %% utils.py
2
+ # parabellum utils
3
+
4
+
5
+ # Imports
6
+ import esch
7
+ import jax.numpy as jnp
8
+ import numpy as np
9
+ from einops import rearrange, repeat
10
+ from jax import tree
11
+ from PIL import Image
12
+ from parabellum.types import Config
13
+
14
# Twilight colors (used in neurocope)
# NOTE(review): svg_fn below passes the literal names "red"/"blue" to esch, not these constants.
red = "#EA344A"
blue = "#2B60F6"
17
+
18
+
19
+ # %% Plotting
20
+ def gif_fn(cfg: Config, seq, scale=4): # animate positions TODO: remove dead units
21
+ pos = seq.pos.astype(int)
22
+ cord = jnp.concat((jnp.arange(pos.shape[0]).repeat(pos.shape[1])[..., None], pos.reshape(-1, 2)), axis=1).T
23
+ idxs = cord[:, seq.hp.flatten().astype(bool) > 0]
24
+ imgs = 1 - np.array(repeat(cfg.map, "... -> a ...", a=len(pos)).at[*idxs].set(1))
25
+ imgs = [Image.fromarray(img).resize(np.array(img.shape[:2]) * scale, Image.NEAREST) for img in imgs * 255] # type: ignore
26
+ imgs[0].save("/Users/nobr/desk/s3/btc2sim/sims.gif", save_all=True, append_images=imgs[1:], duration=10, loop=0)
27
+
28
+
29
def svg_fn(cfg: Config, seq, action, fname, targets=None, fps=2, debug=False):
    """Draw every simulation's trajectories side by side on the terrain map and save an SVG.

    ``seq.pos`` is rearranged ``a b c d -> a c d b`` and its first axis gives the number of panels.
    Units are drawn per team (red for team 1, blue otherwise) and per kind. In debug mode, reach
    (grey) and sight (yellow) circles are added. Optional ``targets`` are drawn as purple squares.
    """
    # set up and background
    n_sims = seq.pos.shape[0]
    e = esch.Drawing(h=cfg.size, w=cfg.size, row=1, col=n_sims, debug=debug, pad=10)
    esch.grid_fn(repeat(np.array(cfg.map, dtype=float), f"... -> {n_sims} ...") * 0.5, e, shape="square")

    # Loop invariants, hoisted: neither depends on the team/kind being drawn.
    positions = np.array(rearrange(seq.pos, "a b c d -> a c d b"), dtype=float)
    unit_types = jnp.unique(cfg.types)

    # loop through teams
    for i in jnp.unique(cfg.teams):
        col = "red" if i == 1 else "blue"

        # loop through kinds
        for j in unit_types:
            mask = (cfg.teams == i) & (cfg.types == j)
            if not mask.any():
                continue  # this team has no units of kind j; nothing to draw
            size, blast = float(cfg.rules.r[j]), float(cfg.rules.blast[j])
            subset = positions[:, mask]
            sub_action = tree.map(lambda x: x[:, :, mask], action)
            esch.sims_fn(e, subset, action=sub_action, fps=fps, col=col, stroke=col, size=size, blast=blast)

            if debug:
                sight, reach = float(cfg.rules.sight[j]), float(cfg.rules.reach[j])
                esch.sims_fn(e, subset, action=None, col="none", fps=fps, size=reach, stroke="grey")
                esch.sims_fn(e, subset, action=None, col="none", fps=fps, size=sight, stroke="yellow")

    if targets is not None:
        pos = np.array(repeat(targets, f"... -> {n_sims} ..."))
        arr = np.ones(pos.shape[:-1])
        esch.mesh_fn(e, pos, arr, shape="square", col="purple")

    # save
    e.dwg.saveas(fname)
@@ -0,0 +1,63 @@
1
+ [project]
2
+ name = "parabellum"
3
+ version = "0.0.73"
4
+ description = "Parabellum environment for parallel warfare simulation"
5
+ authors = [{ name = "Noah Syrkis", email = "desk@syrkis.com" }]
6
+ requires-python = ">=3.11,<3.12"
7
+ readme = "README.md"
8
+ dependencies = [
9
+ "jupyterlab>=4.2.2,<5",
10
+ "poetry>=1.8.3,<2",
11
+ "tqdm>=4.66.4,<5",
12
+ "geopy>=2.4.1,<3",
13
+ "osmnx==2.0.0b0",
14
+ "rasterio>=1.3.10,<2",
15
+ "ipykernel>=6.29.5,<7",
16
+ "folium>=0.17.0,<0.18",
17
+ "pandas>=2.2.2,<3",
18
+ "contextily>=1.6.0,<2",
19
+ "einops>=0.8.0,<0.9",
20
+ "jaxtyping>=0.2.33,<0.3",
21
+ "cartopy>=0.23.0,<0.24",
22
+ "stadiamaps>=3.2.1,<4",
23
+ "cachier>=3.1.2,<4",
24
+ "jax>=0.6.0,<0.7",
25
+ "gymnax>=0.0.8,<0.0.9",
26
+ "evosax>=0.1.6,<0.2",
27
+ "distrax>=0.1.5,<0.2",
28
+ "optax>=0.2.4,<0.3",
29
+ "flax>=0.10.4,<0.11",
30
+ "brax>=0.12.1,<0.13",
31
+ "wandb>=0.19.7,<0.20",
32
+ "flashbax>=0.1.2,<0.2",
33
+ "navix>=0.7.0,<0.8",
34
+ "omegaconf>=2.3.0,<3",
35
+ "jax-tqdm>=0.3.1,<0.4",
36
+ "ipython>=8.36.0",
37
+ "notebook>=7.4.2",
38
+ "equinox>=0.12.2",
39
+ "tensorboard>=2.19.0",
40
+ "tensorflow>=2.19.0",
41
+ "xprof>=2.20.0",
42
+ "jaxkd>=0.1.0",
43
+ "numpy>=2",
44
+ ]
45
+
46
+ [dependency-groups]
47
+ dev = ["esch"]
48
+
49
+ [tool.uv]
50
+
51
+ [tool.uv.sources]
52
+ esch = { path = "../../esch" }
53
+
54
+ [build-system]
55
+ requires = ["hatchling"]
56
+ build-backend = "hatchling.build"
57
+
58
+ [tool.pyright]
59
+ venvPath = "."
60
+ venv = ".venv"
61
+
62
+ [tool.ruff]
63
+ line-length = 120