parabellum 0.3.4__tar.gz → 0.5.13__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {parabellum-0.3.4 → parabellum-0.5.13}/PKG-INFO +17 -15
- parabellum-0.5.13/parabellum/__init__.py +8 -0
- {parabellum-0.3.4 → parabellum-0.5.13}/parabellum/aid.py +16 -2
- parabellum-0.5.13/parabellum/env.py +113 -0
- {parabellum-0.3.4 → parabellum-0.5.13}/parabellum/geo.py +29 -35
- parabellum-0.5.13/parabellum/model.py +6 -0
- parabellum-0.5.13/parabellum/ppo.py +1 -0
- parabellum-0.5.13/parabellum/terrain_db.py +166 -0
- parabellum-0.5.13/parabellum/types.py +54 -0
- parabellum-0.5.13/pyproject.toml +53 -0
- parabellum-0.3.4/parabellum/__init__.py +0 -22
- parabellum-0.3.4/parabellum/env.py +0 -570
- parabellum-0.3.4/parabellum/terrain_db.py +0 -134
- parabellum-0.3.4/parabellum/tps.py +0 -17
- parabellum-0.3.4/pyproject.toml +0 -62
- {parabellum-0.3.4 → parabellum-0.5.13}/README.md +0 -0
- {parabellum-0.3.4 → parabellum-0.5.13}/parabellum/gun.py +0 -0
- {parabellum-0.3.4 → parabellum-0.5.13}/parabellum/pcg.py +0 -0
- {parabellum-0.3.4 → parabellum-0.5.13}/parabellum/run.py +0 -0
- {parabellum-0.3.4 → parabellum-0.5.13}/parabellum/vis.py +0 -0
@@ -1,39 +1,41 @@
|
|
1
|
-
Metadata-Version: 2.
|
1
|
+
Metadata-Version: 2.3
|
2
2
|
Name: parabellum
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.5.13
|
4
4
|
Summary: Parabellum environment for parallel warfare simulation
|
5
|
-
Home-page: https://github.com/syrkis/parabellum
|
6
|
-
License: MIT
|
7
|
-
Keywords: warfare,simulation,parallel,environment
|
8
5
|
Author: Noah Syrkis
|
9
6
|
Author-email: desk@syrkis.com
|
10
7
|
Requires-Python: >=3.11,<3.12
|
11
|
-
Classifier: License :: OSI Approved :: MIT License
|
12
8
|
Classifier: Programming Language :: Python :: 3
|
13
9
|
Classifier: Programming Language :: Python :: 3.11
|
10
|
+
Requires-Dist: brax (>=0.12.1,<0.13.0)
|
11
|
+
Requires-Dist: cachier (>=3.1.2,<4.0.0)
|
14
12
|
Requires-Dist: cartopy (>=0.23.0,<0.24.0)
|
15
13
|
Requires-Dist: contextily (>=1.6.0,<2.0.0)
|
16
|
-
Requires-Dist:
|
14
|
+
Requires-Dist: distrax (>=0.1.5,<0.2.0)
|
17
15
|
Requires-Dist: einops (>=0.8.0,<0.9.0)
|
16
|
+
Requires-Dist: equinox (>=0.11.11,<0.12.0)
|
17
|
+
Requires-Dist: evosax (>=0.1.6,<0.2.0)
|
18
|
+
Requires-Dist: flashbax (>=0.1.2,<0.2.0)
|
19
|
+
Requires-Dist: flax (>=0.10.4,<0.11.0)
|
18
20
|
Requires-Dist: folium (>=0.17.0,<0.18.0)
|
19
21
|
Requires-Dist: geopy (>=2.4.1,<3.0.0)
|
22
|
+
Requires-Dist: gymnax (>=0.0.8,<0.0.9)
|
20
23
|
Requires-Dist: ipykernel (>=6.29.5,<7.0.0)
|
21
|
-
Requires-Dist: jax (
|
22
|
-
Requires-Dist:
|
24
|
+
Requires-Dist: jax (>=0.5.0,<0.6.0)
|
25
|
+
Requires-Dist: jax-tqdm (>=0.3.1,<0.4.0)
|
23
26
|
Requires-Dist: jaxtyping (>=0.2.33,<0.3.0)
|
24
27
|
Requires-Dist: jupyterlab (>=4.2.2,<5.0.0)
|
25
|
-
Requires-Dist:
|
26
|
-
Requires-Dist: numpy (
|
27
|
-
Requires-Dist:
|
28
|
+
Requires-Dist: navix (>=0.7.0,<0.8.0)
|
29
|
+
Requires-Dist: numpy (>=2.2.3,<3.0.0)
|
30
|
+
Requires-Dist: omegaconf (>=2.3.0,<3.0.0)
|
31
|
+
Requires-Dist: optax (>=0.2.4,<0.3.0)
|
28
32
|
Requires-Dist: osmnx (==2.0.0b0)
|
29
33
|
Requires-Dist: pandas (>=2.2.2,<3.0.0)
|
30
34
|
Requires-Dist: poetry (>=1.8.3,<2.0.0)
|
31
|
-
Requires-Dist: pygame (>=2.5.2,<3.0.0)
|
32
35
|
Requires-Dist: rasterio (>=1.3.10,<2.0.0)
|
33
|
-
Requires-Dist: seaborn (>=0.13.2,<0.14.0)
|
34
36
|
Requires-Dist: stadiamaps (>=3.2.1,<4.0.0)
|
35
37
|
Requires-Dist: tqdm (>=4.66.4,<5.0.0)
|
36
|
-
|
38
|
+
Requires-Dist: wandb (>=0.19.7,<0.20.0)
|
37
39
|
Description-Content-Type: text/markdown
|
38
40
|
|
39
41
|
# Parabellum
|
@@ -3,10 +3,9 @@
|
|
3
3
|
# by: Noah Syrkis
|
4
4
|
|
5
5
|
# imports
|
6
|
-
import os
|
7
6
|
from collections import namedtuple
|
8
|
-
from typing import Tuple
|
9
7
|
import cartopy.crs as ccrs
|
8
|
+
import jax.numpy as jnp
|
10
9
|
|
11
10
|
# types
|
12
11
|
BBox = namedtuple("BBox", ["north", "south", "east", "west"]) # type: ignore
|
@@ -23,5 +22,20 @@ def to_mercator(bbox: BBox) -> BBox:
|
|
23
22
|
def to_platecarree(bbox: BBox) -> BBox:
    """Convert a Mercator-projected ``BBox`` to PlateCarree coordinates.

    The south-west corner and the north-east corner are transformed independently
    and reassembled into a new ``BBox``.
    """
    target_crs, source_crs = ccrs.PlateCarree(), ccrs.Mercator()
    west, south = target_crs.transform_point(bbox.west, bbox.south, source_crs)
    east, north = target_crs.transform_point(bbox.east, bbox.north, source_crs)
    return BBox(north=north, south=south, east=east, west=west)
|
28
|
+
|
29
|
+
|
30
|
+
def obstacle_mask_fn(limit):
    """Precompute line-of-sight masks from a source cell to every nearby target cell.

    Returns an int8 array ``mask`` of shape ``(limit, limit, limit, limit)``.
    ``mask[i, j]`` is a ``(limit, limit)`` grid with 1 on the cells sampled along the
    straight segment from ``(0, 0)`` to ``(i, j)`` (both endpoints included) and 0
    elsewhere. ``limit`` must be an int.
    """

    def segment(i, j):
        # i + j + 1 evenly spaced samples from (0, 0) to exactly (i, j); the endpoint is the
        # target cell itself, not (i + 1, j + 1), so cells behind the target are never marked.
        num = i + j + 1
        xs = jnp.linspace(0, i, num)
        ys = jnp.linspace(0, j, num)
        # Round to the nearest cell, and index with int32: int8 would wrap once limit > 127.
        cells = jnp.round(jnp.stack((xs, ys))).astype(jnp.int32)
        return jnp.zeros((limit, limit)).at[cells[0], cells[1]].set(1)

    # Row-major over (i, j) so the reshape below yields mask[i, j].
    grids = jnp.stack([segment(i, j) for i in range(limit) for j in range(limit)])
    return grids.astype(jnp.int8).reshape(limit, limit, limit, limit)
|
@@ -0,0 +1,113 @@
|
|
1
|
+
# env.py
|
2
|
+
# parabellum env
|
3
|
+
# by: Noah Syrkis
|
4
|
+
|
5
|
+
# % Imports
|
6
|
+
import jax.numpy as jnp
|
7
|
+
from jax import random, Array, lax, vmap, debug
|
8
|
+
import jax.numpy.linalg as la
|
9
|
+
from typing import Tuple
|
10
|
+
from functools import partial
|
11
|
+
|
12
|
+
from parabellum.geo import geography_fn
|
13
|
+
from parabellum.types import Action, State, Obs, Scene
|
14
|
+
from parabellum import aid
|
15
|
+
import equinox as eqx
|
16
|
+
|
17
|
+
|
18
|
+
# %% Dataclass ################################################################
|
19
|
+
class Env:
    """Thin object wrapper around the functional ``init_fn`` / ``step_fn`` / ``obs_fn`` API.

    ``cfg.counts.allies`` and ``cfg.counts.enemies`` must be mappings whose values are unit
    counts; the module-level functions additionally read ``cfg.size`` and ``cfg.knn``.
    """

    def __init__(self, cfg):
        self.cfg = cfg

    def reset(self, rng: Array, scene: Scene) -> Tuple[Obs, State]:
        """Create the initial state and return ``(obs, state)``."""
        return init_fn(rng, self, scene)

    def step(self, rng: Array, scene: Scene, state: State, action: Action) -> Tuple[Obs, State]:
        """Apply ``action`` and return ``(obs, new_state)``.

        The observation describes ``new_state``, so obs and state always refer to the same
        time step (consistent with ``reset``).
        """
        new_state = step_fn(rng, self, scene, state, action)
        return obs_fn(self, scene, new_state), new_state

    @property
    def num_units(self):
        """Total number of units across both teams."""
        return self.num_allies + self.num_enemies

    @property
    def num_allies(self):
        """Number of allied units (sum of ``cfg.counts.allies`` values)."""
        return sum(self.cfg.counts.allies.values())

    @property
    def num_enemies(self):
        """Number of enemy units (sum of ``cfg.counts.enemies`` values)."""
        return sum(self.cfg.counts.enemies.values())
|
40
|
+
|
41
|
+
|
42
|
+
# %% Functions ################################################################
|
43
|
+
@eqx.filter_jit
def init_fn(rng: Array, env: Env, scene: Scene) -> Tuple[Obs, State]:  # initialize -----
    """Build the starting ``State`` and its observation.

    Every unit starts at its type's full health with zero cooldown. Positions are drawn
    from a normal distribution (std 2) centred on ``env.cfg.size / 2``. No check against
    buildings or map bounds is made here.
    """
    _, pos_key = random.split(rng)
    n_units = env.num_units
    start_health = jnp.ones(n_units) * scene.unit_type_health[scene.unit_types]
    centre = env.cfg.size / 2
    start_pos = random.normal(pos_key, (scene.unit_types.size, 2)) * 2 + centre
    state = State(unit_position=start_pos, unit_health=start_health, unit_cooldown=jnp.zeros(n_units))
    return obs_fn(env, scene, state), state
|
50
|
+
|
51
|
+
|
52
|
+
@eqx.filter_jit  # env.cfg.knn is fixed for a given env, so the traced shapes never change
def obs_fn(env, scene: Scene, state: State) -> Obs:
    """Describe each unit's ``env.cfg.knn`` nearest neighbours, zeroing hidden ones.

    Neighbour ids come from ``lax.approx_min_k`` over all pairwise distances. Relative
    position (observer minus neighbour), health and cooldown are multiplied by the
    visibility mask from ``mask_fn``. The ids themselves are not masked.
    """
    pos = state.unit_position
    pairwise = la.norm(pos[:, None] - pos, axis=-1)  # (N, N) euclidean distances
    dists, idxs = lax.approx_min_k(pairwise, k=env.cfg.knn)
    visible = mask_fn(scene, state, dists, idxs)
    return Obs(
        unit_id=idxs,
        unit_pos=(pos[:, None, ...] - pos[idxs]) * visible[..., None],
        unit_health=state.unit_health[idxs] * visible,
        unit_cooldown=state.unit_cooldown[idxs] * visible,
    )
|
61
|
+
|
62
|
+
|
63
|
+
@eqx.filter_jit
def step_fn(rng, env: Env, scene: Scene, state: State, action: Action) -> State:  # update agents ---
    """Advance one tick: move non-attacking units, then apply blast damage.

    Units whose ``action.kinds`` is 1 do not move. A proposed move is rejected (the unit
    keeps its old position) if any coordinate is < 0 or >= ``env.cfg.size``, or if it lands
    on a cell where ``scene.terrain.building > 0``. Health comes from ``blast_fn``, which is
    evaluated on the pre-move positions. Cooldowns are carried over unchanged.
    """
    old_pos = state.unit_position
    moving = 1 - action.kinds[..., None]  # 0 for attackers, 1 for movers
    proposed = old_pos + action.coord * moving
    off_map = (proposed < 0).any(axis=-1) | (proposed >= env.cfg.size).any(axis=-1)
    cell = proposed.astype(jnp.int32)
    on_building = scene.terrain.building[cell[:, 0], cell[:, 1]] > 0
    rejected = (off_map | on_building)[..., None]
    new_pos = jnp.where(rejected, old_pos, proposed)
    new_health = blast_fn(rng, env, scene, state, action)
    return State(unit_position=new_pos, unit_health=new_health, unit_cooldown=state.unit_cooldown)
|
71
|
+
|
72
|
+
|
73
|
+
def blast_fn(rng, env: Env, scene: Scene, state: State, action: Action):  # update agents ---
    """Apply one round of area damage and return the updated per-unit health.

    Every unit whose ``action.kinds`` is true detonates a blast at
    ``unit_position + coord``. Each unit within the *attacker's* reach of that point loses
    the *attacker's* damage. There is no team filter, so allies and the attacker itself can
    be hit. ``rng`` and ``env`` are unused but kept for signature compatibility.
    """
    pos = state.unit_position
    blast_centre = pos + action.coord  # one blast centre per unit i
    # dist[i, j]: distance from unit i's blast centre to unit j.
    dist = la.norm(pos[None, ...] - blast_centre[:, None, ...], axis=-1)
    # Attacker stats, indexed by the attacker i (axis 0 of dist).
    reach = scene.unit_type_reach[scene.unit_types]
    damage = scene.unit_type_damage[scene.unit_types]
    attacking = action.kinds.astype(bool)
    # Mask out non-attackers explicitly: a zero reach threshold would otherwise register a
    # zero-distance self-hit for a unit standing still (coord == 0).
    hits = (dist <= reach[:, None]) & attacking[:, None]
    # Sum over attackers (axis 0) so damage lands on the targets j, not on the attacker.
    received = (hits * damage[:, None]).sum(axis=0)
    return state.unit_health - received
|
78
|
+
|
79
|
+
|
80
|
+
# @eqx.filter_jit
|
81
|
+
def scene_fn(cfg):  # init's a scene
    """Build the ``Scene`` for a config.

    ``cfg.types`` lists unit types (each with a ``name`` plus the stats in ``attrs``), and
    the ``unit_type_*`` tables are ordered by type name. ``cfg.counts.allies`` and
    ``cfg.counts.enemies`` map type names to unit counts; each unit's type index is looked
    up by name, so it always matches the tables even if a team fields only some types.
    NOTE(review): assumes counts keys are exactly the ``name`` fields of ``cfg.types``.
    """
    types_by_name = sorted(cfg.types, key=lambda t: t.name)
    type_index = {t.name: i for i, t in enumerate(types_by_name)}
    attrs = ["health", "damage", "reload", "reach", "sight", "speed"]
    kwargs = {f"unit_type_{a}": jnp.array([t[a] for t in types_by_name]) for a in attrs}
    kwargs["terrain"] = geography_fn(cfg.place, cfg.size)

    num_allies, num_enemies = sum(cfg.counts.allies.values()), sum(cfg.counts.enemies.values())
    unit_teams = jnp.concat((jnp.zeros(num_allies), jnp.ones(num_enemies))).astype(jnp.int32)

    def team_types(team):
        # One entry per unit, grouped by type name in sorted order; works for empty teams too.
        return jnp.array(
            [type_index[name] for name, count in sorted(cfg.counts[team].items()) for _ in range(count)],
            dtype=jnp.int32,
        )

    unit_types = jnp.concat((team_types("allies"), team_types("enemies"))).astype(jnp.int32)
    mask = aid.obstacle_mask_fn(max(t["sight"] for t in cfg.types))
    return Scene(unit_teams=unit_teams, unit_types=unit_types, mask=mask, **kwargs)  # type: ignore
|
91
|
+
|
92
|
+
|
93
|
+
@eqx.filter_jit
def mask_fn(scene, state, dists, idxs):
    """Return a boolean mask, shaped like ``dists``, of which neighbours each unit can see.

    A neighbour is visible only if it is strictly within the observer's sight range AND
    ``obstacle_fn`` reports a clear line of sight (no building under the precomputed
    line mask).
    """
    in_range = dists < scene.unit_type_sight[scene.unit_types][..., None]
    # int32, not int8: map coordinates above 127 would wrap around in int8.
    cells = state.unit_position[idxs].astype(jnp.int32)
    unobstructed = obstacle_fn(scene, cells)
    return in_range & unobstructed
|
98
|
+
|
99
|
+
|
100
|
+
@partial(vmap, in_axes=(None, 0))  # vmapped over observers; pos is one observer's (k, 2) neighbour cells
def obstacle_fn(scene, pos):
    """Flag, for one observer, which of its neighbours have a clear line of sight.

    ``pos[0]`` is treated as the observer's own cell. NOTE(review): this relies on the first
    nearest neighbour returned by ``approx_min_k`` being the unit itself; approximate top-k
    or co-located units may break that — confirm.
    """
    observer = pos[0]
    return slice_fn(scene, observer, pos)
|
104
|
+
|
105
|
+
|
106
|
+
@partial(vmap, in_axes=(None, None, 0))
def slice_fn(scene, source, target):
    """Return True when no building lies under the line mask from ``source`` to ``target``.

    Vmapped over ``target``; ``source`` is one integer cell. An ``L x L`` window of
    ``scene.terrain.building`` (``L = scene.mask.shape[-1]``) is cut so that it extends
    from ``source`` towards ``target``. Each axis is then flipped independently so that
    ``source`` sits at (0, 0) and ``target`` at ``|source - target|``, and the window is
    multiplied by ``scene.mask[|dx|, |dy|]``.
    """
    size = scene.mask.shape[-1]
    offset = source - target
    # +1 where the target lies at a lower-or-equal index than the source along that axis.
    delta = (offset >= 0) * 2 - 1
    # The window covers [source - (size - 1), source] towards lower indices,
    # or [source, source + size - 1] towards higher ones.
    start = jnp.where(delta == 1, source - (size - 1), source)
    # Zero-pad (no buildings off-map) so dynamic_slice never clamps and misaligns the window.
    padded = jnp.pad(scene.terrain.building, size)
    view = lax.dynamic_slice(padded, start + size, (size, size))
    # Flip each axis on its own; flipping without an axis would reverse both at once.
    view = jnp.where(delta[0] == 1, jnp.flip(view, axis=0), view)
    view = jnp.where(delta[1] == 1, jnp.flip(view, axis=1), view)
    dist = jnp.abs(offset)
    return (scene.mask[dist[0], dist[1]] * view).sum() == 0
|
@@ -3,25 +3,23 @@
|
|
3
3
|
# by: Noah Syrkis
|
4
4
|
|
5
5
|
# %% Imports
|
6
|
-
from parabellum import tps
|
7
|
-
import rasterio
|
8
6
|
from rasterio import features, transform
|
7
|
+
from jax import tree
|
9
8
|
from geopy.geocoders import Nominatim
|
10
9
|
from geopy.distance import distance
|
11
10
|
import contextily as cx
|
12
11
|
import jax.numpy as jnp
|
13
12
|
import cartopy.crs as ccrs
|
14
13
|
from jaxtyping import Array
|
15
|
-
import numpy as np
|
16
14
|
from shapely import box
|
17
15
|
import osmnx as ox
|
18
16
|
import geopandas as gpd
|
19
17
|
from collections import namedtuple
|
20
18
|
from typing import Tuple
|
21
19
|
import matplotlib.pyplot as plt
|
22
|
-
|
23
|
-
import os
|
20
|
+
from cachier import cachier
|
24
21
|
from jax.scipy.signal import convolve
|
22
|
+
from parabellum.types import Terrain
|
25
23
|
|
26
24
|
# %% Types
|
27
25
|
Coords = Tuple[float, float]
|
@@ -79,27 +77,23 @@ def basemap_fn(bbox: BBox, gdf) -> Array:
|
|
79
77
|
return image
|
80
78
|
|
81
79
|
|
82
|
-
|
80
|
+
@cachier()
def geography_fn(place, buffer) -> Terrain:
    """Rasterise OpenStreetMap features around ``place`` into a ``Terrain``.

    ``buffer`` is passed to ``get_bbox`` and also used as the raster side length
    (``buffer x buffer`` cells). Results are memoised per argument set by ``cachier``.
    Raster channels, presumably in ``tags`` order: 0 building, 1 water, 2 highway,
    3 forest, 4 garden. Water cells within a 3x3 neighbourhood of water-and-highway
    cells (bridges) are cleared, and forest is the union of channels 3 and 4. Every
    layer is rotated by ``rot90(., 3)``, and all leaves are cast to int16.
    NOTE(review): basemap is cast to int16 too — confirm basemap_fn never returns floats in [0, 1].
    """
    bbox = get_bbox(place, buffer)
    features = ox.features_from_bbox(bbox=bbox, tags=tags)
    frame = gpd.GeoDataFrame(features)
    frame = frame.clip(box(bbox.west, bbox.south, bbox.east, bbox.north)).to_crs("EPSG:3857")
    raster = raster_fn(frame, shape=(buffer, buffer))
    basemap = jnp.rot90(basemap_fn(bbox, frame), 3)

    def rotate(layer):
        return jnp.rot90(layer, 3)

    neighbourhood = jnp.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]])
    bridges = convolve(raster[1] * raster[2], neighbourhood, mode="same")
    terrain = Terrain(
        building=rotate(raster[0]),
        water=rotate(raster[1] - bridges > 0),
        forest=rotate(jnp.logical_or(raster[3], raster[4])),
        basemap=basemap,
    )
    return tree.map(lambda leaf: leaf.astype(jnp.int16), terrain)
|
104
98
|
|
105
99
|
|
@@ -128,25 +122,25 @@ def feature_fn(t, feature, gdf, shape):
|
|
128
122
|
|
129
123
|
|
130
124
|
# %%
|
131
|
-
def normalize(x):
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
def get_bridges(gdf):
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
125
|
+
# def normalize(x):
|
126
|
+
# return (np.array(x) - m) / (M - m)
|
127
|
+
|
128
|
+
|
129
|
+
# def get_bridges(gdf):
|
130
|
+
# xmin, ymin, xmax, ymax = gdf.total_bounds
|
131
|
+
# m = np.array([xmin, ymin])
|
132
|
+
# M = np.array([xmax, ymax])
|
133
|
+
|
134
|
+
# bridges = {}
|
135
|
+
# for idx, bridge in gdf[gdf["bridge"] == "yes"].iterrows():
|
136
|
+
# if type(bridge["name"]) == str:
|
137
|
+
# bridges[idx[1]] = {
|
138
|
+
# "name": bridge["name"],
|
139
|
+
# "coords": normalize(
|
140
|
+
# [bridge.geometry.centroid.x, bridge.geometry.centroid.y]
|
141
|
+
# ),
|
142
|
+
# }
|
143
|
+
# return bridges
|
150
144
|
|
151
145
|
|
152
146
|
"""
|
@@ -0,0 +1 @@
|
|
1
|
+
|
@@ -0,0 +1,166 @@
|
|
1
|
+
# %%
|
2
|
+
import numpy as np
|
3
|
+
import jax.numpy as jnp
|
4
|
+
from parabellum.types import Terrain
|
5
|
+
|
6
|
+
|
7
|
+
# %%
|
8
|
+
def map_raster_from_line(raster, line, size):
    """Mark a straight segment on ``raster`` (in place) with 1 and return it.

    ``line`` is ``[x0, y0, dx, dy]`` in fractions of ``size``. Each value is scaled by
    ``size`` and truncated with ``int``. The segment runs from ``(x0, y0)`` to
    ``(x0 + dx, y0 + dy)`` and is sampled ``int(sqrt(2) * size) + 1`` times. Samples
    falling outside ``[0, size)`` on either axis are skipped.
    """
    x_start, y_start, x_len, y_len = (int(v * size) for v in line)
    steps = int(2**0.5 * size)

    def coord(start, length, frac):
        # A zero-length axis stays exactly on `start`, avoiding float round-off below it.
        return start if length == 0 else int((1 - frac) * start + frac * (start + length))

    for step in range(steps + 1):
        frac = step / float(steps)
        x, y = coord(x_start, x_len, frac), coord(y_start, y_len, frac)
        if 0 <= x < size and 0 <= y < size:
            raster[x, y] = 1
    return raster
|
22
|
+
|
23
|
+
|
24
|
+
# %%
|
25
|
+
def map_raster_from_rect(raster, rect, size):
    """Fill an axis-aligned rectangle of ``raster`` with 1 (in place) and return it.

    ``rect`` is ``[x0, y0, dx, dy]`` in fractions of ``size``. Each value is scaled by
    ``size`` and truncated with ``int``. The filled region is ``[x0, x0 + dx) x [y0, y0 + dy)``,
    and any part past the raster edge is dropped by numpy slicing.
    """
    x_lo, y_lo, x_span, y_span = (int(v * size) for v in rect)
    raster[x_lo : x_lo + x_span, y_lo : y_lo + y_span] = 1
    return raster
|
33
|
+
|
34
|
+
|
35
|
+
# %%
|
36
|
+
# RGBA colours (0-255 per channel) that make_terrain paints into the basemap.
building_color, water_color, forest_color, empty_color = (
    jnp.array(rgba)
    for rgba in (
        [201, 199, 198, 255],  # building
        [193, 237, 254, 255],  # water
        [197, 214, 185, 255],  # forest
        [255, 255, 255, 255],  # empty ground
    )
)
|
40
|
+
|
41
|
+
|
42
|
+
def make_terrain(terrain_args, size):
    """Build a synthetic ``Terrain`` from a ``db`` scenario description.

    ``terrain_args`` maps each layer name (it must include "building", "water" and
    "forest") to ``None`` or a list of ``{"line": ...}`` / ``{"rect": ...}`` primitives.
    Each layer is rasterised on a ``size x size`` grid, transposed and converted to a jax
    array. A colour basemap is then painted with later layers on top:
    building over empty, then water, then forest.
    """
    layers = {}
    for name, shapes in terrain_args.items():
        grid = np.zeros((size, size))
        for shape in [] if shapes is None else shapes:
            if "line" in shape:
                grid = map_raster_from_line(grid, shape["line"], size)
            elif "rect" in shape:
                grid = map_raster_from_rect(grid, shape["rect"], size)
        layers[name] = jnp.array(grid.T)

    # The (size, size, 1) condition broadcasts against the (4,) RGBA colours.
    basemap = jnp.where(layers["building"][:, :, None], building_color, empty_color)
    for name, colour in (("water", water_color), ("forest", forest_color)):
        basemap = jnp.where(layers[name][:, :, None], colour, basemap)
    layers["basemap"] = basemap
    return Terrain(**layers)
|
60
|
+
|
61
|
+
|
62
|
+
# %%
|
63
|
+
# Scenario database consumed by make_terrain(db[name], size).
# Each layer ("building", "water", "forest") is either None (an empty layer) or a list of primitives:
#   {"line": [x0, y0, dx, dy]} -> map_raster_from_line: segment from (x0, y0) to (x0 + dx, y0 + dy)
#   {"rect": [x0, y0, dx, dy]} -> map_raster_from_rect: fills [x0, x0 + dx) x [y0, y0 + dy)
# All values are fractions of the raster size; they are scaled by size and truncated with int().
# Rectangles reaching past 1.0 are clipped at the raster edge.
db = {
    "blank": {"building": None, "water": None, "forest": None},
    "F": {
        "building": [
            {"line": [0.25, 0.33, 0.5, 0]},
            {"line": [0.75, 0.33, 0.0, 0.25]},
            {"line": [0.50, 0.33, 0.0, 0.25]},
        ],
        "water": None,
        "forest": None,
    },
    "stronghold": {
        "building": [
            {"line": [0.2, 0.275, 0.2, 0.0]},
            {"line": [0.2, 0.275, 0.0, 0.2]},
            {"line": [0.4, 0.275, 0.0, 0.2]},
            {"line": [0.2, 0.475, 0.2, 0.0]},
            {"line": [0.2, 0.525, 0.2, 0.0]},
            {"line": [0.2, 0.525, 0.0, 0.2]},
            {"line": [0.4, 0.525, 0.0, 0.2]},
            {"line": [0.2, 0.725, 0.525, 0.0]},
            {"line": [0.75, 0.25, 0.0, 0.2]},
            {"line": [0.75, 0.55, 0.0, 0.19]},
            {"line": [0.6, 0.25, 0.15, 0.0]},
        ],
        "water": None,
        "forest": None,
    },
    "playground": {"building": [{"line": [0.5, 0.5, 0.5, 0.0]}], "water": None, "forest": None},
    "playground2": {
        "building": [],
        "water": [{"rect": [0.0, 0.8, 0.1, 0.1]}, {"rect": [0.2, 0.8, 0.8, 0.1]}],
        "forest": [{"rect": [0.0, 0.0, 1.0, 0.2]}],
    },
    "triangle": {
        "building": [{"line": [0.33, 0.0, 0.0, 1.0]}, {"line": [0.66, 0.0, 0.0, 1.0]}],
        "water": None,
        "forest": None,
    },
    "u_shape": {
        "building": [],
        "water": [{"rect": [0.15, 0.2, 0.1, 0.5]}, {"rect": [0.4, 0.2, 0.1, 0.5]}, {"rect": [0.2, 0.2, 0.25, 0.1]}],
        "forest": [],
    },
    "bridges": {
        "building": [],
        # Two crossing water strips (x ~ 0.475-0.525 and y ~ 0.475-0.525), each split into
        # three pieces; the gaps between pieces are presumably the crossable "bridges".
        "water": [
            {"rect": [0.475, 0.0, 0.05, 0.1]},
            {"rect": [0.475, 0.15, 0.05, 0.575]},
            {"rect": [0.475, 0.775, 0.05, 1.0]},
            {"rect": [0.0, 0.475, 0.225, 0.05]},
            {"rect": [0.275, 0.475, 0.45, 0.05]},
            {"rect": [0.775, 0.475, 0.23, 0.05]},
        ],
        "forest": [
            {"rect": [0.1, 0.625, 0.275, 0.275]},
            {"rect": [0.725, 0.0, 0.3, 0.275]},
        ],
    },
}
|
123
|
+
|
124
|
+
# %% [raw]
|
125
|
+
# import matplotlib.pyplot as plt
|
126
|
+
# size = 100
|
127
|
+
# raster = np.zeros((size, size))
|
128
|
+
# rect = [0.475, 0., 0.05, 0.1]
|
129
|
+
# raster = map_raster_from_rect(raster, rect, size)
|
130
|
+
# rect = [0.475, 0.15, 0.05, 0.575]
|
131
|
+
# raster = map_raster_from_rect(raster, rect, size)
|
132
|
+
# rect = [0.475, 0.775, 0.05, 1.]
|
133
|
+
# raster = map_raster_from_rect(raster, rect, size)
|
134
|
+
#
|
135
|
+
# rect = [0., 0.475, 0.225, 0.05]
|
136
|
+
# raster = map_raster_from_rect(raster, rect, size)
|
137
|
+
# rect = [0.275, 0.475, 0.45, 0.05]
|
138
|
+
# raster = map_raster_from_rect(raster, rect, size)
|
139
|
+
# rect = [0.775, 0.475, 0.23, 0.05]
|
140
|
+
# raster = map_raster_from_rect(raster, rect, size)
|
141
|
+
#
|
142
|
+
# rect = [0.1, 0.625, 0.275, 0.275]
|
143
|
+
# raster = map_raster_from_rect(raster, rect, size)
|
144
|
+
# rect = [0.725, 0., 0.3, 0.275]
|
145
|
+
# raster = map_raster_from_rect(raster, rect, size)
|
146
|
+
#
|
147
|
+
# plt.imshow(raster[::-1, :])
|
148
|
+
|
149
|
+
# %% [markdown]
|
150
|
+
# # Main
|
151
|
+
|
152
|
+
# %%
|
153
|
+
if __name__ == "__main__":
    # Quick visual check of one synthetic scenario; only runs when executed as a script.
    import matplotlib.pyplot as plt

    # %%
    terrain = make_terrain(db["bridges"], size=100)

    # %%
    plt.imshow(jnp.rot90(terrain.basemap))
    # bl / tr presumably mark bottom-left / top-right reference points — confirm.
    bl = (39.5, 5)
    tr = (44.5, 10)
    # NOTE(review): the hard-coded 49 looks tied to a 50-cell map (size / 2 - 1), but size is
    # 100 above — confirm the intended flip for plotting these points.
    plt.scatter(bl[0], 49 - bl[1])
    plt.scatter(tr[0], 49 - tr[1], marker="+")

    # %%
|
@@ -0,0 +1,54 @@
|
|
1
|
+
# types.py
|
2
|
+
# parabellum types
|
3
|
+
# by: Noah Syrkis
|
4
|
+
|
5
|
+
# imports
|
6
|
+
from chex import dataclass
|
7
|
+
from jaxtyping import Array, Bool, Float16
|
8
|
+
|
9
|
+
|
10
|
+
# dataclasses
|
11
|
+
@dataclass
class State:
    """Per-step simulation state; each field has one leading entry per unit."""

    # Continuous map coordinates; env.init_fn creates it with shape (num_units, 2).
    unit_position: Array
    # Remaining health; env.blast_fn subtracts blast damage from it.
    unit_health: Array
    # env.init_fn sets it to zeros; env.step_fn currently carries it over unchanged.
    unit_cooldown: Array
|
16
|
+
|
17
|
+
|
18
|
+
@dataclass
class Obs:
    """What each unit observes about its k nearest neighbours (built by env.obs_fn)."""

    # Neighbour indices from lax.approx_min_k, shape (num_units, knn); not masked.
    unit_id: Array
    # Observer position minus neighbour position, zeroed for hidden neighbours.
    unit_pos: Array
    # Neighbour health, zeroed for hidden neighbours.
    unit_health: Array
    # Neighbour cooldown, zeroed for hidden neighbours.
    unit_cooldown: Array
|
24
|
+
|
25
|
+
|
26
|
+
@dataclass
class Action:
    """Per-unit action consumed by env.step_fn / env.blast_fn."""

    # Offset from the unit's position: a move step when kinds is 0, the blast point when kinds is 1.
    coord: Float16[Array, "... 2"]  # noqa
    # 1/True = attack (blast at position + coord, no movement); 0/False = move by coord.
    kinds: Bool[Array, "..."]
|
30
|
+
|
31
|
+
|
32
|
+
@dataclass
class Terrain:
    """Rasterised map layers plus a colour image used for rendering."""

    # Cells > 0 block movement (env.step_fn) and line of sight (env.slice_fn).
    building: Array
    water: Array
    forest: Array
    # Colour image; terrain_db.make_terrain paints it as a (size, size, 4) RGBA array.
    basemap: Array
|
38
|
+
|
39
|
+
|
40
|
+
@dataclass
class Scene:
    """Scenario data built by env.scene_fn and passed unchanged to reset/step."""

    terrain: Terrain
    # Line-of-sight lookup from aid.obstacle_mask_fn, shape (L, L, L, L) with L = max sight.
    mask: Array

    # Per-unit index into the unit_type_* tables below.
    unit_types: Array
    # Per-unit team id: 0 for allies, 1 for enemies (see env.scene_fn).
    unit_teams: Array

    # Per-type stat tables, indexed by unit type (ordered by type name in env.scene_fn).
    unit_type_health: Array
    unit_type_damage: Array
    unit_type_reload: Array

    unit_type_reach: Array
    unit_type_sight: Array
    unit_type_speed: Array
|
@@ -0,0 +1,53 @@
|
|
1
|
+
[tool.poetry]
|
2
|
+
name = "parabellum"
|
3
|
+
version = "0.5.13"
|
4
|
+
description = "Parabellum environment for parallel warfare simulation"
|
5
|
+
authors = ["Noah Syrkis <desk@syrkis.com>"]
|
6
|
+
readme = "README.md"
|
7
|
+
|
8
|
+
[tool.poetry.dependencies]
|
9
|
+
python = "^3.11,<3.12"
|
10
|
+
jupyterlab = "^4.2.2"
|
11
|
+
poetry = "^1.8.3"
|
12
|
+
tqdm = "^4.66.4"
|
13
|
+
geopy = "^2.4.1"
|
14
|
+
osmnx = "2.0.0b0"
|
15
|
+
rasterio = "^1.3.10"
|
16
|
+
ipykernel = "^6.29.5"
|
17
|
+
folium = "^0.17.0"
|
18
|
+
pandas = "^2.2.2"
|
19
|
+
contextily = "^1.6.0"
|
20
|
+
einops = "^0.8.0"
|
21
|
+
jaxtyping = "^0.2.33"
|
22
|
+
cartopy = "^0.23.0"
|
23
|
+
stadiamaps = "^3.2.1"
|
24
|
+
cachier = "^3.1.2"
|
25
|
+
equinox = "^0.11.11"
|
26
|
+
jax = "^0.5.0"
|
27
|
+
gymnax = "^0.0.8"
|
28
|
+
evosax = "^0.1.6"
|
29
|
+
distrax = "^0.1.5"
|
30
|
+
optax = "^0.2.4"
|
31
|
+
flax = "^0.10.4"
|
32
|
+
numpy = "^2.2.3"
|
33
|
+
brax = "^0.12.1"
|
34
|
+
wandb = "^0.19.7"
|
35
|
+
flashbax = "^0.1.2"
|
36
|
+
navix = "^0.7.0"
|
37
|
+
omegaconf = "^2.3.0"
|
38
|
+
jax-tqdm = "^0.3.1"
|
39
|
+
|
40
|
+
[tool.poetry.group.dev.dependencies]
|
41
|
+
esch = { path = "../../esch" }
|
42
|
+
purejaxrl = { path = "../purejaxrl" }
|
43
|
+
|
44
|
+
[build-system]
|
45
|
+
requires = ["poetry-core"]
|
46
|
+
build-backend = "poetry.core.masonry.api"
|
47
|
+
|
48
|
+
[tool.pyright]
|
49
|
+
venvPath = "."
|
50
|
+
venv = ".venv"
|
51
|
+
|
52
|
+
[tool.ruff]
|
53
|
+
line-length = 120
|
@@ -1,22 +0,0 @@
|
|
1
|
-
from .env import Environment, Scenario, make_scenario, State
|
2
|
-
from .vis import Visualizer, Skin
|
3
|
-
from .gun import bullet_fn
|
4
|
-
from . import vis
|
5
|
-
from . import terrain_db
|
6
|
-
from . import env
|
7
|
-
from . import tps
|
8
|
-
# from .run import run
|
9
|
-
|
10
|
-
__all__ = [
|
11
|
-
"env",
|
12
|
-
"terrain_db",
|
13
|
-
"vis",
|
14
|
-
"tps",
|
15
|
-
"Environment",
|
16
|
-
"Scenario",
|
17
|
-
"make_scenario",
|
18
|
-
"State",
|
19
|
-
"Visualizer",
|
20
|
-
"Skin",
|
21
|
-
"bullet_fn",
|
22
|
-
]
|