parabellum 0.0.0__tar.gz → 0.0.73__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- parabellum-0.0.73/.gitignore +16 -0
- parabellum-0.0.73/PKG-INFO +46 -0
- parabellum-0.0.73/README.md +3 -0
- parabellum-0.0.73/main.py +57 -0
- parabellum-0.0.73/parabellum/__init__.py +4 -0
- parabellum-0.0.73/parabellum/env.py +70 -0
- parabellum-0.0.73/parabellum/geo.py +213 -0
- parabellum-0.0.73/parabellum/types.py +138 -0
- parabellum-0.0.73/parabellum/utils.py +59 -0
- parabellum-0.0.73/pyproject.toml +63 -0
- parabellum-0.0.73/uv.lock +4426 -0
- parabellum-0.0.0/PKG-INFO +0 -55
- parabellum-0.0.0/README.md +0 -31
- parabellum-0.0.0/parabellum/.ipynb_checkpoints/__init__-checkpoint.py +0 -4
- parabellum-0.0.0/parabellum/.ipynb_checkpoints/env-checkpoint.py +0 -296
- parabellum-0.0.0/parabellum/.ipynb_checkpoints/vis-checkpoint.py +0 -230
- parabellum-0.0.0/parabellum/__init__.py +0 -4
- parabellum-0.0.0/parabellum/env.py +0 -296
- parabellum-0.0.0/parabellum/vis.py +0 -230
- parabellum-0.0.0/pyproject.toml +0 -29
@@ -0,0 +1,46 @@
|
|
1
|
+
Metadata-Version: 2.4
|
2
|
+
Name: parabellum
|
3
|
+
Version: 0.0.73
|
4
|
+
Summary: Parabellum environment for parallel warfare simulation
|
5
|
+
Author-email: Noah Syrkis <desk@syrkis.com>
|
6
|
+
Requires-Python: <3.12,>=3.11
|
7
|
+
Requires-Dist: brax<0.13,>=0.12.1
|
8
|
+
Requires-Dist: cachier<4,>=3.1.2
|
9
|
+
Requires-Dist: cartopy<0.24,>=0.23.0
|
10
|
+
Requires-Dist: contextily<2,>=1.6.0
|
11
|
+
Requires-Dist: distrax<0.2,>=0.1.5
|
12
|
+
Requires-Dist: einops<0.9,>=0.8.0
|
13
|
+
Requires-Dist: equinox>=0.12.2
|
14
|
+
Requires-Dist: evosax<0.2,>=0.1.6
|
15
|
+
Requires-Dist: flashbax<0.2,>=0.1.2
|
16
|
+
Requires-Dist: flax<0.11,>=0.10.4
|
17
|
+
Requires-Dist: folium<0.18,>=0.17.0
|
18
|
+
Requires-Dist: geopy<3,>=2.4.1
|
19
|
+
Requires-Dist: gymnax<0.0.9,>=0.0.8
|
20
|
+
Requires-Dist: ipykernel<7,>=6.29.5
|
21
|
+
Requires-Dist: ipython>=8.36.0
|
22
|
+
Requires-Dist: jax-tqdm<0.4,>=0.3.1
|
23
|
+
Requires-Dist: jax<0.7,>=0.6.0
|
24
|
+
Requires-Dist: jaxkd>=0.1.0
|
25
|
+
Requires-Dist: jaxtyping<0.3,>=0.2.33
|
26
|
+
Requires-Dist: jupyterlab<5,>=4.2.2
|
27
|
+
Requires-Dist: navix<0.8,>=0.7.0
|
28
|
+
Requires-Dist: notebook>=7.4.2
|
29
|
+
Requires-Dist: numpy>=2
|
30
|
+
Requires-Dist: omegaconf<3,>=2.3.0
|
31
|
+
Requires-Dist: optax<0.3,>=0.2.4
|
32
|
+
Requires-Dist: osmnx==2.0.0b0
|
33
|
+
Requires-Dist: pandas<3,>=2.2.2
|
34
|
+
Requires-Dist: poetry<2,>=1.8.3
|
35
|
+
Requires-Dist: rasterio<2,>=1.3.10
|
36
|
+
Requires-Dist: stadiamaps<4,>=3.2.1
|
37
|
+
Requires-Dist: tensorboard>=2.19.0
|
38
|
+
Requires-Dist: tensorflow>=2.19.0
|
39
|
+
Requires-Dist: tqdm<5,>=4.66.4
|
40
|
+
Requires-Dist: wandb<0.20,>=0.19.7
|
41
|
+
Requires-Dist: xprof>=2.20.0
|
42
|
+
Description-Content-Type: text/markdown
|
43
|
+
|
44
|
+
# Parabellum
|
45
|
+
|
46
|
+
TODO: switch to red and blue team semantics (not enemy and ally)
|
@@ -0,0 +1,57 @@
|
|
1
|
+
# %% main.py
|
2
|
+
# parabellum main
|
3
|
+
# by: Noah Syrkis
|
4
|
+
|
5
|
+
# Imports ###################################################################
|
6
|
+
from functools import partial
|
7
|
+
from typing import Tuple
|
8
|
+
import esch
|
9
|
+
|
10
|
+
from einops import repeat, rearrange
|
11
|
+
import numpy as np
|
12
|
+
from jax import jit, lax, random, tree, vmap
|
13
|
+
from jax import numpy as jnp
|
14
|
+
from jaxtyping import Array
|
15
|
+
|
16
|
+
import parabellum as pb
|
17
|
+
from parabellum.types import Action, Config, State
|
18
|
+
|
19
|
+
|
20
|
+
# %% Functions
|
21
|
+
def action_fn(cfg: Config, rng: Array) -> Action:
    """Sample a uniformly random action for every unit.

    Target offsets are drawn in [-1, 1]^2 and scaled by each unit's per-type
    reach; the action kind is drawn from {0, 1, 2} (0 = invalid, 1 = move,
    2 = cast, as documented on ``parabellum.types.Action.kind``).

    The two draws use independent subkeys: feeding the same key to both
    ``uniform`` and ``randint`` would make offsets and kinds correlated.
    """
    pos_key, kind_key = random.split(rng)
    reach = cfg.rules.reach[cfg.types][..., None]  # (length, 1) per-unit reach
    pos = random.uniform(pos_key, (cfg.length, 2), minval=-1, maxval=1) * reach
    kind = random.randint(kind_key, (cfg.length,), minval=0, maxval=3)
    return Action(pos=pos, kind=kind)
|
25
|
+
|
26
|
+
|
27
|
+
def step_fn(state: State, rng: Array) -> Tuple[State, Tuple[State, Action]]:
    """Advance one simulation by a single random step (``lax.scan`` body).

    Uses the module-level ``env`` and ``cfg`` defined in the Main section.
    Returns the new state as the scan carry and ``(state, action)`` as the
    per-step output.
    """
    # Separate keys: reusing one key for action sampling and the env's own
    # noise would correlate the two random streams.
    act_key, env_key = random.split(rng)
    action = action_fn(cfg, act_key)
    _obs, state = env.step(cfg, env_key, state, action)
    return state, (state, action)
|
31
|
+
|
32
|
+
|
33
|
+
def traj_fn(state, rng) -> Tuple[State, Tuple[State, Action]]:
    """Roll one simulation forward ``cfg.steps`` steps starting from ``state``.

    Returns the final state together with the stacked per-step
    ``(state, action)`` history produced by ``step_fn``.
    """
    step_keys = random.split(rng, cfg.steps)  # one key per scan iteration
    final, history = lax.scan(step_fn, state, step_keys)
    return final, history
|
36
|
+
|
37
|
+
|
38
|
+
# %% Main
env, cfg = pb.env.Env(), Config()
# Split into a (2, sims, ...) batch: first half seeds initial placement,
# second half seeds the rollouts.
init_key, traj_key = random.split(random.PRNGKey(0), (2, cfg.sims))

init = vmap(jit(partial(env.init, cfg)))
traj = vmap(jit(traj_fn))

obs, state = init(init_key)
# Roll out with traj_key (not init_key) so rollout randomness is independent
# of the keys that produced the initial placement.
state, (seq, action) = traj(state, traj_key)

pb.utils.svg_fn(cfg, seq, action, "/Users/nobr/desk/s3/parabellum/sims.svg", debug=True)
|
49
|
+
# %% Anim
|
50
|
+
# for i in range(seq.pos.shape[0]): # sims
|
51
|
+
# for j in range(seq.pos.shape[2]): # units
|
52
|
+
# shots = [(kdx, coord) for kdx, coord in enumerate(action.pos[i, :, j]) if action.shoot[i, kdx, j]]
|
53
|
+
# print(shots)
|
54
|
+
# print(tree.map(jnp.shape, action))
|
55
|
+
# print(tree.map(jnp.shape, seq))
|
56
|
+
# pb.utils.svg_fn(cfg, seq.pos, action)
|
57
|
+
# pb.utils.gif_fn(cfg, seq)
|
@@ -0,0 +1,70 @@
|
|
1
|
+
# %% env.py
|
2
|
+
# parabellum env
|
3
|
+
# by: Noah Syrkis
|
4
|
+
|
5
|
+
# Imports
|
6
|
+
from functools import partial
|
7
|
+
from typing import Tuple
|
8
|
+
|
9
|
+
import jax.numpy as jnp
|
10
|
+
import jaxkd as jk
|
11
|
+
from jax import random
|
12
|
+
from jaxtyping import Array
|
13
|
+
|
14
|
+
from parabellum.types import Action, Config, Obs, State
|
15
|
+
|
16
|
+
|
17
|
+
# %% Dataclass
|
18
|
+
class Env:
    """Attribute-free wrapper exposing ``init``/``step`` over the module-level functions."""

    def init(self, cfg: Config, rng: Array) -> Tuple[Obs, State]:
        """Sample a starting state and return ``(obs, state)``."""
        fresh = init_fn(cfg, rng)  # without jit this takes forever
        return obs_fn(cfg, fresh), fresh

    def step(self, cfg: Config, rng: Array, state: State, action: Action) -> Tuple[Obs, State]:
        """Apply ``action`` to ``state`` and return ``(obs, next_state)``."""
        nxt = step_fn(cfg, rng, state, action)
        return obs_fn(cfg, nxt), nxt
|
26
|
+
|
27
|
+
|
28
|
+
# %% Functions
|
29
|
+
def init_fn(cfg: Config, rng: Array) -> State:
    """Place every unit on a random free cell of the map.

    Cells where ``cfg.map`` is set get zero sampling weight; all other cells
    are sampled uniformly *with replacement*, so several units may share a
    cell. Each unit starts with the HP of its type.
    """
    # jnp.where rather than `.at[cfg.map].set(0)`: boolean-mask indexing needs
    # a concrete mask and raises if cfg.map is ever traced under jit/vmap.
    prob = jnp.where(cfg.map, 0.0, 1.0).flatten()
    flat = random.choice(rng, prob.size, shape=(cfg.length,), p=prob, replace=True)
    # Decode the row-major flat index with the column count (equal to the
    # row count for the square maps produced today).
    cols = cfg.map.shape[1]
    pos = jnp.float32(jnp.column_stack((flat // cols, flat % cols)))
    return State(pos=pos, hp=cfg.hp[cfg.types])
|
35
|
+
|
36
|
+
|
37
|
+
def obs_fn(cfg: Config, state: State) -> Obs:  # return info about neighbors ---
    """Build each unit's observation of its ``cfg.knn`` nearest units.

    The code treats neighbour slot 0 as the observing unit itself: that slot
    carries the unit's absolute position, while the other slots hold positions
    relative to it. Neighbours farther than the observer's sight are zeroed
    out and flagged False in ``mask``.
    """
    idxs, dist = jk.extras.query_neighbors_pairwise(state.pos, state.pos, k=cfg.knn)
    own_type = cfg.types[idxs][:, 0]  # type of the observing unit (slot 0)
    mask = dist < cfg.sight[own_type][..., None]  # | (state.hp[idxs] > 0)
    rel = state.pos[idxs] - state.pos[:, None, ...]
    pos = rel.at[:, 0, :].set(state.pos) * mask[..., None]
    # cfg.reach / cfg.sight / cfg.speed are per-*type* tables (one entry per
    # unit kind), so expand them to per-unit values before gathering by unit
    # index. Indexing them directly with `idxs` read the wrong rows, and JAX
    # clamps out-of-range gather indices instead of raising.
    per_unit = (
        state.hp,
        cfg.types,
        cfg.teams,
        cfg.reach[cfg.types],
        cfg.sight[cfg.types],
        cfg.speed[cfg.types],
    )
    hp, kind, team, reach, sight, speed = (x[idxs] * mask for x in per_unit)
    return Obs(pos=pos, dist=dist, hp=hp, type=kind, team=team, reach=reach, sight=sight, speed=speed, mask=mask)
|
44
|
+
|
45
|
+
|
46
|
+
def step_fn(cfg: Config, rng: Array, state: State, action: Action) -> State:
    """Advance the world one tick: move units, perturb them, then apply damage."""
    # NOTE(review): if jaxkd's signature is (points, query, k), then the
    # *targets* (state.pos + action.pos) are the searched points and unit
    # positions are the queries, so idx[i] lists casters whose target lands
    # near unit i; blast_fn then adds unit i's damage to those indices.
    # Confirm this attacker/victim direction is intended.
    idx, norm = jk.extras.query_neighbors_pairwise(state.pos + action.pos, state.pos, k=2)
    shared = (rng, cfg, state, action, idx, norm)
    moved = move_fn(*shared)
    new_pos = push_fn(cfg, rng, idx, norm, moved)
    new_hp = blast_fn(*shared)
    return State(pos=new_pos, hp=new_hp)
|
50
|
+
|
51
|
+
|
52
|
+
def move_fn(rng: Array, cfg: Config, state: State, action: Action, idx: Array, norm: Array) -> Array:
|
53
|
+
speed = cfg.speed[cfg.types][..., None] # max speed of a unit (step size, really)
|
54
|
+
pos = state.pos + action.pos.clip(-speed, speed) * action.move[..., None] # new poss
|
55
|
+
mask = ((pos < 0).any(axis=-1) | ((pos >= cfg.size).any(axis=-1)) | (cfg.map[*jnp.int32(pos).T] > 0))[..., None]
|
56
|
+
return jnp.where(mask, state.pos, pos) # compute new position
|
57
|
+
|
58
|
+
|
59
|
+
def blast_fn(rng: Array, cfg: Config, state: State, action: Action, idx: Array, norm: Array) -> Array:
    """Return HP after subtracting damage from casting units.

    Unit ``i``'s type damage (zero unless unit ``i`` casts) is applied once to
    every index in row ``i`` of ``idx``; hits on the same index accumulate.
    Negative damage (the medic type) therefore heals. HP is not clamped.
    """
    per_unit = cfg.dam[cfg.types] * action.cast  # 0 for non-casting units
    spread = jnp.broadcast_to(per_unit[..., None], idx.shape)
    total = jnp.zeros(cfg.length, dtype=jnp.int32).at[idx.flatten()].add(spread.flatten())
    return state.hp - total
|
62
|
+
|
63
|
+
|
64
|
+
def push_fn(cfg: Config, rng: Array, idx: Array, norm: Array, pos: Array, repel: bool = False) -> Array:
    """Perturb positions with Gaussian jitter, optionally repelling close neighbours.

    By default only jitter (std 0.1) is applied. Before this change that was
    the only reachable path: the repulsion code sat after an early ``return``
    and was dead. That code is now opt-in via ``repel=True``, because its
    parameters still need tuning against unit size.

    Args:
        cfg: config; ``cfg.r`` (per-type radius), ``cfg.types`` and
            ``cfg.force`` are read only when ``repel`` is True.
        rng: key for the jitter.
        idx: neighbour indices into ``pos``, one row per unit (repel only).
        norm: matching neighbour distances (repel only).
        pos: positions to perturb.
        repel: enable the experimental repulsion term (a static Python bool).

    Returns:
        Perturbed positions with the same shape as ``pos``.
    """
    if repel:
        # params need to be tweaked, and matched with unit size
        pos_diff = pos[:, None, :] - pos[idx]  # direction away from neighbors
        close = (norm < cfg.r[cfg.types][..., None]) & (norm > 0)
        push = jnp.where(close[..., None], pos_diff * cfg.force / (norm[..., None] + 1e-6), 0.0)
        pos = pos + push.sum(axis=1)
    return pos + random.normal(rng, pos.shape) * 0.1
|
@@ -0,0 +1,213 @@
|
|
1
|
+
# %% geo.py
|
2
|
+
# script for geospatial level generation
|
3
|
+
# by: Noah Syrkis
|
4
|
+
|
5
|
+
# %% Imports
|
6
|
+
from rasterio import features, transform
|
7
|
+
|
8
|
+
# from jax import tree
|
9
|
+
from geopy.geocoders import Nominatim
|
10
|
+
from geopy.distance import distance
|
11
|
+
import contextily as cx
|
12
|
+
import jax.numpy as jnp
|
13
|
+
import cartopy.crs as ccrs
|
14
|
+
from jaxtyping import Array
|
15
|
+
from shapely import box
|
16
|
+
import osmnx as ox
|
17
|
+
import geopandas as gpd
|
18
|
+
from collections import namedtuple
|
19
|
+
from typing import Tuple
|
20
|
+
import matplotlib.pyplot as plt
|
21
|
+
from cachier import cachier
|
22
|
+
# from jax.scipy.signal import convolve
|
23
|
+
# from parabellum.types import Terrain
|
24
|
+
|
25
|
+
# %% Types
|
26
|
+
Coords = Tuple[float, float]
BBox = namedtuple("BBox", ["north", "south", "east", "west"])  # type: ignore

# %% Constants
import os  # noqa: E402  (needed only for the API-key lookup below)

# The Stadia Maps API key must not live in source: it was shipped in the
# published sdist, so the previously committed key should be rotated. Read it
# from the STADIA_API_KEY environment variable instead. With no key set, tile
# requests will presumably be rejected by the provider.
provider = cx.providers.Stadia.StamenTerrain(  # type: ignore
    api_key=os.environ.get("STADIA_API_KEY", "")
)
provider["url"] = provider["url"] + "?api_key={api_key}"

# OSM tag filter passed to osmnx; raster_fn emits one layer per key, in this order.
tags = {
    "building": True,
    "water": True,
    "highway": True,
    "landuse": [
        "grass",
        "forest",
        "flowerbed",
        "greenfield",
        "village_green",
        "recreation_ground",
    ],
    "leisure": "garden",
}  # "road": True}
|
48
|
+
|
49
|
+
|
50
|
+
# %% Coordinate function
|
51
|
+
def get_coordinates(place: str) -> Coords:
    """Geocode ``place`` with Nominatim and return ``(latitude, longitude)``.

    Raises:
        ValueError: if Nominatim returns no match for ``place``.
    """
    geolocator = Nominatim(user_agent="parabellum")
    point = geolocator.geocode(place)
    if point is None:  # geopy signals "no result" with None, not an exception
        raise ValueError(f"could not geocode place: {place!r}")
    return point.latitude, point.longitude
|
55
|
+
|
56
|
+
|
57
|
+
def get_bbox(place: str, buffer) -> BBox:
    """Get bounding box from place name in crs 4326.

    The box extends ``buffer`` metres from the geocoded centre of ``place`` in
    each cardinal direction and is returned as ``BBox(north, south, east, west)``
    in degrees.
    """
    centre = get_coordinates(place)
    offset = distance(meters=buffer)

    def toward(bearing):
        # point `buffer` metres from the centre along the given compass bearing
        return offset.destination(centre, bearing=bearing)

    north, south = toward(0).latitude, toward(180).latitude
    east, west = toward(90).longitude, toward(270).longitude
    return BBox(north, south, east, west)  # type: ignore
|
65
|
+
|
66
|
+
|
67
|
+
def basemap_fn(bbox: BBox, gdf) -> Array:
    """Render the Stadia terrain basemap under ``gdf`` and return the canvas pixels.

    ``bbox`` is currently unused: the extent comes from ``gdf.total_bounds``.
    The figure is always closed, even when tile download or drawing fails.
    """
    fig, ax = plt.subplots(figsize=(20, 20), subplot_kw={"projection": ccrs.Mercator()})
    try:
        gdf.plot(ax=ax, color="black", alpha=0, edgecolor="black")  # type: ignore  # alpha=0: fully transparent
        cx.add_basemap(ax, crs=gdf.crs, source=provider, zoom="auto")  # type: ignore
        west, south, east, north = gdf.total_bounds
        # NOTE(review): geography_fn reprojects gdf to EPSG:3857 (spherical Web
        # Mercator), while ccrs.Mercator() defaults to an ellipsoidal Mercator;
        # confirm the extents line up (ccrs.Mercator.GOOGLE may be intended).
        ax.set_extent([west, east, south, north], crs=ccrs.Mercator())  # type: ignore
        plt.axis("off")
        plt.tight_layout(pad=0)
        fig.canvas.draw()
        # private Agg renderer buffer; presumably RGBA pixels — confirm
        image = jnp.array(fig.canvas.renderer._renderer)  # type: ignore
    finally:
        plt.close(fig)
    return image
|
79
|
+
|
80
|
+
|
81
|
+
@cachier()
def geography_fn(place, buffer):
    """Rasterise OSM features within ``buffer`` metres of ``place`` into a boolean grid.

    The grid is ``buffer x buffer`` cells covering a box ``2 * buffer`` metres
    wide, so each cell spans roughly two metres. Only the first ``tags`` layer
    (``"building"``) is returned; the other rasterised layers are discarded.
    Results are memoised by ``cachier``.
    """
    bbox = get_bbox(place, buffer)
    # NOTE(review): osmnx 2.x documents bbox as (left, bottom, right, top),
    # whereas BBox is (north, south, east, west); confirm the pinned
    # osmnx==2.0.0b0 still accepts this order.
    found = gpd.GeoDataFrame(ox.features_from_bbox(bbox=bbox, tags=tags))
    window = box(bbox.west, bbox.south, bbox.east, bbox.north)
    projected = found.clip(window).to_crs("EPSG:3857")  # crop, then Web Mercator
    layers = raster_fn(projected, shape=(buffer, buffer))
    # Earlier revisions combined water/forest/basemap layers into a Terrain
    # record; only the building mask is used now.
    return jnp.bool(layers[0])
|
99
|
+
|
100
|
+
|
101
|
+
# =======
|
102
|
+
# terrain = tps.Terrain(building=trans(raster[0] - convolve(raster[0]*raster[2], kernel, mode='same')>0),
|
103
|
+
# water=trans(raster[1] - convolve(raster[1]*raster[2], kernel, mode='same')>0),
|
104
|
+
# forest=trans(jnp.logical_or(raster[3], raster[4])),
|
105
|
+
# basemap=basemap)
|
106
|
+
# return terrain, gdf
|
107
|
+
# >>>>>>> aeb13033e57083cc512a60f8f60a3db47a65ac32
|
108
|
+
|
109
|
+
|
110
|
+
def raster_fn(gdf, shape) -> Array:
    """Rasterise each ``tags`` key of ``gdf`` onto a ``shape = (rows, cols)`` grid.

    Returns the layers stacked along a new leading axis, in ``tags`` key order.
    """
    rows, cols = shape
    # from_bounds takes (west, south, east, north, width, height) — columns
    # before rows — while rasterize's out_shape is (rows, cols). Passing
    # `*shape` directly only worked for square grids.
    t = transform.from_bounds(*gdf.total_bounds, cols, rows)  # type: ignore
    return jnp.array([feature_fn(t, feature, gdf, shape) for feature in tags])
|
115
|
+
|
116
|
+
|
117
|
+
def feature_fn(t, feature, gdf, shape):
    """Burn the geometries of rows where column ``feature`` is set into a raster.

    Covered cells get rasterize's default burn value (1); all others are 0.
    Returns an all-zero ``shape`` array when ``feature`` is not a column of
    ``gdf`` or when no row has it set.
    """
    if feature not in gdf.columns:
        return jnp.zeros(shape)
    present = gdf[~gdf[feature].isna()]
    if len(present) == 0:
        # nothing to burn; rasterize may reject an empty geometry list
        return jnp.zeros(shape)
    return features.rasterize(present.geometry, out_shape=shape, transform=t, fill=0)  # type: ignore
|
123
|
+
|
124
|
+
|
125
|
+
# %%
|
126
|
+
# def normalize(x):
|
127
|
+
# return (np.array(x) - m) / (M - m)
|
128
|
+
|
129
|
+
|
130
|
+
# def get_bridges(gdf):
|
131
|
+
# xmin, ymin, xmax, ymax = gdf.total_bounds
|
132
|
+
# m = np.array([xmin, ymin])
|
133
|
+
# M = np.array([xmax, ymax])
|
134
|
+
|
135
|
+
# bridges = {}
|
136
|
+
# for idx, bridge in gdf[gdf["bridge"] == "yes"].iterrows():
|
137
|
+
# if type(bridge["name"]) == str:
|
138
|
+
# bridges[idx[1]] = {
|
139
|
+
# "name": bridge["name"],
|
140
|
+
# "coords": normalize(
|
141
|
+
# [bridge.geometry.centroid.x, bridge.geometry.centroid.y]
|
142
|
+
# ),
|
143
|
+
# }
|
144
|
+
# return bridges
|
145
|
+
|
146
|
+
|
147
|
+
"""
|
148
|
+
# %%
|
149
|
+
if __name__ == "__main__":
|
150
|
+
place = "Thun, Switzerland"
|
151
|
+
<<<<<<< HEAD
|
152
|
+
terrain = geography_fn(place, 300)
|
153
|
+
|
154
|
+
=======
|
155
|
+
terrain, gdf = geography_fn(place, 300)
|
156
|
+
|
157
|
+
>>>>>>> aeb13033e57083cc512a60f8f60a3db47a65ac32
|
158
|
+
fig, axes = plt.subplots(1, 5, figsize=(20, 20))
|
159
|
+
axes[0].imshow(jnp.rot90(terrain.building), cmap="gray")
|
160
|
+
axes[1].imshow(jnp.rot90(terrain.water), cmap="gray")
|
161
|
+
axes[2].imshow(jnp.rot90(terrain.forest), cmap="gray")
|
162
|
+
axes[3].imshow(jnp.rot90(terrain.building + terrain.water + terrain.forest))
|
163
|
+
axes[4].imshow(jnp.rot90(terrain.basemap))
|
164
|
+
|
165
|
+
# %%
|
166
|
+
W, H, _ = terrain.basemap.shape
|
167
|
+
bridges = get_bridges(gdf)
|
168
|
+
|
169
|
+
# %%
|
170
|
+
print("Bridges:")
|
171
|
+
for bridge in bridges.values():
|
172
|
+
x, y = int(bridge["coords"][0]*300), int(bridge["coords"][1]*300)
|
173
|
+
print(bridge["name"], f"at ({x}, {y})")
|
174
|
+
|
175
|
+
# %%
|
176
|
+
plt.subplots(figsize=(7,7))
|
177
|
+
plt.imshow(jnp.rot90(terrain.basemap))
|
178
|
+
X = [b["coords"][0]*W for b in bridges.values()]
|
179
|
+
Y = [(1-b["coords"][1])*H for b in bridges.values()]
|
180
|
+
plt.scatter(X, Y)
|
181
|
+
for i in range(len(X)):
|
182
|
+
x,y = int(X[i]), int(Y[i])
|
183
|
+
plt.text(x, y, str((int(x/W*300), int((1-(y/H))*300))))
|
184
|
+
|
185
|
+
# %%
|
186
|
+
|
187
|
+
# %% [raw]
|
188
|
+
# fig, ax = plt.subplots(figsize=(10, 10))
|
189
|
+
# gdf.plot(ax=ax, color='lightgray') # Plot all features
|
190
|
+
# bridges.plot(ax=ax, color='red') # Highlight bridges in red
|
191
|
+
# plt.show()
|
192
|
+
|
193
|
+
# %%
|
194
|
+
|
195
|
+
"""
|
196
|
+
|
197
|
+
# BBox = namedtuple("BBox", ["north", "south", "east", "west"]) # type: ignore
|
198
|
+
|
199
|
+
|
200
|
+
# def to_mercator(bbox: BBox) -> BBox:
|
201
|
+
# proj = ccrs.Mercator()
|
202
|
+
# west, south = proj.transform_point(bbox.west, bbox.south, ccrs.PlateCarree())
|
203
|
+
# east, north = proj.transform_point(bbox.east, bbox.north, ccrs.PlateCarree())
|
204
|
+
# return BBox(north=north, south=south, east=east, west=west)
|
205
|
+
#
|
206
|
+
#
|
207
|
+
# def to_platecarree(bbox: BBox) -> BBox:
|
208
|
+
# proj = ccrs.PlateCarree()
|
209
|
+
# west, south = proj.transform_point(bbox.west, bbox.south, ccrs.Mercator())
|
210
|
+
#
|
211
|
+
# east, north = proj.transform_point(bbox.east, bbox.north, ccrs.Mercator())
|
212
|
+
# return BBox(north=north, south=south, east=east, west=west)
|
213
|
+
#
|
@@ -0,0 +1,138 @@
|
|
1
|
+
# types.py
|
2
|
+
# parabellum types
|
3
|
+
# by: Noah Syrkis
|
4
|
+
|
5
|
+
# imports
|
6
|
+
from chex import dataclass
|
7
|
+
from jaxtyping import Array, Float32, Int
|
8
|
+
import jax.numpy as jnp
|
9
|
+
from parabellum.geo import geography_fn
|
10
|
+
from dataclasses import field
|
11
|
+
|
12
|
+
|
13
|
+
@dataclass
class Kind:
    """Stat block for a single unit type (concrete values live in ``Rules``)."""

    hp: int  # starting hit points (init_fn seeds State.hp from it)
    dam: int  # damage per cast, subtracted from targets' hp; negative heals
    speed: int  # per-axis cap on a move's displacement
    reach: int  # scale of sampled target offsets; drawn in debug renders
    sight: int  # distance within which neighbours are observed
    blast: int  # forwarded to the renderer as the cast size
    r: float  # unit radius for rendering and (optional) neighbour repulsion
|
22
|
+
|
23
|
+
|
24
|
+
@dataclass
class Rules:
    """Per-type unit statistics, also exposed as arrays indexed by type id.

    Type ids are troop=0, armor=1, plane=2, civil=3, medic=4, the same order
    ``Team`` uses when building its per-unit ``types``.
    """

    troop = Kind(hp=120, dam=15, speed=2, reach=4, sight=4, blast=1, r=1)
    armor = Kind(hp=150, dam=12, speed=1, reach=8, sight=16, blast=3, r=2)
    plane = Kind(hp=80, dam=20, speed=4, reach=16, sight=32, blast=2, r=2)
    civil = Kind(hp=100, dam=0, speed=3, reach=5, sight=10, blast=1, r=2)
    medic = Kind(hp=100, dam=-10, speed=3, reach=5, sight=10, blast=1, r=2)

    def __post_init__(self):
        # One lookup array per stat, so e.g. self.hp[type_id] is that type's hp.
        units = (self.troop, self.armor, self.plane, self.civil, self.medic)
        for stat in ("hp", "dam", "r", "speed", "reach", "sight", "blast"):
            setattr(self, stat, jnp.array(tuple(getattr(unit, stat) for unit in units)))
|
48
|
+
|
49
|
+
|
50
|
+
@dataclass
class Team:
    """Unit counts per type for one side; derives ``length`` and per-unit ``types``."""

    troop: int = 1
    armor: int = 0
    plane: int = 0
    civil: int = 0
    medic: int = 0

    def __post_init__(self):
        counts = (self.troop, self.armor, self.plane, self.civil, self.medic)
        self.length: int = sum(counts)  # total units on this side
        # one type id per unit, grouped by type: all troops first, then armor, ...
        self.types: Array = jnp.repeat(jnp.arange(5), jnp.array(counts))
|
63
|
+
|
64
|
+
|
65
|
+
# dataclasses
|
66
|
+
@dataclass
class State:
    """Dynamic simulation state with one row per unit."""

    pos: Array  # (length, 2) float map coordinates
    hp: Array  # (length,) current hit points
    # target: Array
|
71
|
+
|
72
|
+
|
73
|
+
@dataclass
class Obs:
    """Per-unit view of its nearest neighbours, as built by ``env.obs_fn``.

    Fields are indexed ``[..., neighbour]`` (``pos`` adds a trailing coordinate
    axis). Neighbour slot 0 is the observing unit itself; out-of-sight slots
    are zeroed and False in ``mask``.
    """

    # idxs: Array
    hp: Array
    pos: Array
    type: Array
    team: Array
    dist: Array
    mask: Array
    reach: Array
    sight: Array
    speed: Array

    @property
    def ally(self):
        """Visible neighbours on the observer's own team (observer included)."""
        # `[..., :1]` picks each observer's own team (slot 0) per row. The
        # former `team[0]` took the first *row* when observations are batched
        # over units, comparing every unit to unit 0's neighbours. For a single
        # unit's (knn,) observation both forms give the same result.
        return (self.team == self.team[..., :1]) & self.mask

    @property
    def enemy(self):
        """Visible neighbours on a different team from the observer."""
        return (self.team != self.team[..., :1]) & self.mask
|
93
|
+
|
94
|
+
|
95
|
+
@dataclass
class Action:
    """Per-unit action: a target offset ``pos`` and an integer ``kind`` code."""

    pos: Array
    kind: Int[Array, "..."]  # 0 = invalid, 1 = move, 2 = cast

    @property
    def invalid(self):
        """Mask of units whose action has kind 0 (neither move nor cast)."""
        return self.kind == 0

    @property
    def move(self):
        """Mask of units that move by ``pos`` (kind 1)."""
        return self.kind == 1

    @property
    def cast(self):
        """Mask of units casting a bomb, bullet or medicine (kind 2)."""
        return self.kind == 2
|
111
|
+
|
112
|
+
|
113
|
+
@dataclass
class Config:  # Remove frozen=True for now
    """Simulation settings plus derived per-unit and per-type lookup arrays."""

    steps: int = 123  # rollout length (number of scan steps)
    place: str = "Palazzo della Civiltà Italiana, Rome, Italy"  # geocoded map centre
    force: float = 0.5  # neighbour-repulsion strength
    sims: int = 2  # number of parallel simulations
    size: int = 64  # map side in cells; also the geocoding buffer in metres
    knn: int = 2  # neighbours per observation
    blu: Team = field(default_factory=lambda: Team())
    red: Team = field(default_factory=lambda: Team())
    rules: Rules = field(default_factory=lambda: Rules())

    def __post_init__(self):
        # Everything below is derived once, at construction time.
        blu, red, rules = self.blu, self.red, self.rules
        self.types: Array = jnp.concat((blu.types, red.types))  # per-unit type id, blu first
        self.teams: Array = jnp.repeat(jnp.arange(2), jnp.array((blu.length, red.length)))  # 0 = blu, 1 = red
        self.map: Array = geography_fn(self.place, self.size)  # obstacle grid (geocoding + OSM, cached)
        # per-type stat tables, indexed by type id
        self.hp: Array = rules.hp
        self.dam: Array = rules.dam
        self.r: Array = rules.r
        self.speed: Array = rules.speed
        self.reach: Array = rules.reach
        self.sight: Array = rules.sight
        self.blast: Array = rules.blast
        self.length: int = blu.length + red.length
        self.root: Array = jnp.int32(jnp.sqrt(self.length))
|
@@ -0,0 +1,59 @@
|
|
1
|
+
# %% utils.py
|
2
|
+
# parabellum utils
|
3
|
+
|
4
|
+
|
5
|
+
# Imports
|
6
|
+
import esch
|
7
|
+
import jax.numpy as jnp
|
8
|
+
import numpy as np
|
9
|
+
from einops import rearrange, repeat
|
10
|
+
from jax import tree
|
11
|
+
from PIL import Image
|
12
|
+
from parabellum.types import Config
|
13
|
+
|
14
|
+
# Twilight colors (used in neurocope)
|
15
|
+
red = "#EA344A"
|
16
|
+
blue = "#2B60F6"
|
17
|
+
|
18
|
+
|
19
|
+
# %% Plotting
|
20
|
+
def gif_fn(cfg: Config, seq, scale=4): # animate positions TODO: remove dead units
|
21
|
+
pos = seq.pos.astype(int)
|
22
|
+
cord = jnp.concat((jnp.arange(pos.shape[0]).repeat(pos.shape[1])[..., None], pos.reshape(-1, 2)), axis=1).T
|
23
|
+
idxs = cord[:, seq.hp.flatten().astype(bool) > 0]
|
24
|
+
imgs = 1 - np.array(repeat(cfg.map, "... -> a ...", a=len(pos)).at[*idxs].set(1))
|
25
|
+
imgs = [Image.fromarray(img).resize(np.array(img.shape[:2]) * scale, Image.NEAREST) for img in imgs * 255] # type: ignore
|
26
|
+
imgs[0].save("/Users/nobr/desk/s3/btc2sim/sims.gif", save_all=True, append_images=imgs[1:], duration=10, loop=0)
|
27
|
+
|
28
|
+
|
29
|
+
def svg_fn(cfg: Config, seq, action, fname, targets=None, fps=2, debug=False):
    """Render batched simulations side by side as an animated SVG via ``esch``.

    One panel per simulation (leading axis of ``seq.pos``) over a half-intensity
    map background. Units are drawn per (team, type): team 1 in red, others in
    blue. With ``debug``, reach (grey) and sight (yellow) radii are outlined;
    ``targets``, if given, are drawn as purple squares.
    """
    n_sims = seq.pos.shape[0]
    drawing = esch.Drawing(h=cfg.size, w=cfg.size, row=1, col=n_sims, debug=debug, pad=10)
    background = repeat(np.array(cfg.map, dtype=float), f"... -> {n_sims} ...") * 0.5
    esch.grid_fn(background, drawing, shape="square")

    # Reorder positions once, outside the loops:
    # (sims, steps, units, xy) -> (sims, units, xy, steps).
    all_pos = np.array(rearrange(seq.pos, "a b c d -> a c d b"), dtype=float)

    for team in jnp.unique(cfg.teams):
        colour = "red" if team == 1 else "blue"
        for kind in jnp.unique(cfg.types):
            sel = (cfg.teams == team) & (cfg.types == kind)
            radius, blast = float(cfg.rules.r[kind]), float(cfg.rules.blast[kind])
            subset = all_pos[:, sel]
            sub_action = tree.map(lambda x: x[:, :, sel], action)
            esch.sims_fn(drawing, subset, action=sub_action, fps=fps, col=colour, stroke=colour, size=radius, blast=blast)
            if debug:
                sight, reach = float(cfg.rules.sight[kind]), float(cfg.rules.reach[kind])
                esch.sims_fn(drawing, subset, action=None, col="none", fps=fps, size=reach, stroke="grey")
                esch.sims_fn(drawing, subset, action=None, col="none", fps=fps, size=sight, stroke="yellow")

    if targets is not None:
        target_pos = np.array(repeat(targets, f"... -> {n_sims} ..."))
        esch.mesh_fn(drawing, target_pos, np.ones(target_pos.shape[:-1]), shape="square", col="purple")

    drawing.dwg.saveas(fname)
|
@@ -0,0 +1,63 @@
|
|
1
|
+
[project]
|
2
|
+
name = "parabellum"
|
3
|
+
version = "0.0.73"
|
4
|
+
description = "Parabellum environment for parallel warfare simulation"
|
5
|
+
authors = [{ name = "Noah Syrkis", email = "desk@syrkis.com" }]
|
6
|
+
requires-python = ">=3.11,<3.12"
|
7
|
+
readme = "README.md"
|
8
|
+
dependencies = [
|
9
|
+
"jupyterlab>=4.2.2,<5",
|
10
|
+
"poetry>=1.8.3,<2",
|
11
|
+
"tqdm>=4.66.4,<5",
|
12
|
+
"geopy>=2.4.1,<3",
|
13
|
+
"osmnx==2.0.0b0",
|
14
|
+
"rasterio>=1.3.10,<2",
|
15
|
+
"ipykernel>=6.29.5,<7",
|
16
|
+
"folium>=0.17.0,<0.18",
|
17
|
+
"pandas>=2.2.2,<3",
|
18
|
+
"contextily>=1.6.0,<2",
|
19
|
+
"einops>=0.8.0,<0.9",
|
20
|
+
"jaxtyping>=0.2.33,<0.3",
|
21
|
+
"cartopy>=0.23.0,<0.24",
|
22
|
+
"stadiamaps>=3.2.1,<4",
|
23
|
+
"cachier>=3.1.2,<4",
|
24
|
+
"jax>=0.6.0,<0.7",
|
25
|
+
"gymnax>=0.0.8,<0.0.9",
|
26
|
+
"evosax>=0.1.6,<0.2",
|
27
|
+
"distrax>=0.1.5,<0.2",
|
28
|
+
"optax>=0.2.4,<0.3",
|
29
|
+
"flax>=0.10.4,<0.11",
|
30
|
+
"brax>=0.12.1,<0.13",
|
31
|
+
"wandb>=0.19.7,<0.20",
|
32
|
+
"flashbax>=0.1.2,<0.2",
|
33
|
+
"navix>=0.7.0,<0.8",
|
34
|
+
"omegaconf>=2.3.0,<3",
|
35
|
+
"jax-tqdm>=0.3.1,<0.4",
|
36
|
+
"ipython>=8.36.0",
|
37
|
+
"notebook>=7.4.2",
|
38
|
+
"equinox>=0.12.2",
|
39
|
+
"tensorboard>=2.19.0",
|
40
|
+
"tensorflow>=2.19.0",
|
41
|
+
"xprof>=2.20.0",
|
42
|
+
"jaxkd>=0.1.0",
|
43
|
+
"numpy>=2",
|
44
|
+
]
|
45
|
+
|
46
|
+
[dependency-groups]
|
47
|
+
dev = ["esch"]
|
48
|
+
|
49
|
+
[tool.uv]
|
50
|
+
|
51
|
+
[tool.uv.sources]
|
52
|
+
esch = { path = "../../esch" }
|
53
|
+
|
54
|
+
[build-system]
|
55
|
+
requires = ["hatchling"]
|
56
|
+
build-backend = "hatchling.build"
|
57
|
+
|
58
|
+
[tool.pyright]
|
59
|
+
venvPath = "."
|
60
|
+
venv = ".venv"
|
61
|
+
|
62
|
+
[tool.ruff]
|
63
|
+
line-length = 120
|