parabellum 0.2.25__py3-none-any.whl → 0.2.26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
parabellum/env.py
CHANGED
@@ -60,7 +60,7 @@ def make_scenario(place, terrain_raster, unit_starting_sectors, allies_type, n_a
|
|
60
60
|
else:
|
61
61
|
assert(len(allies_type) == n_allies)
|
62
62
|
allies = allies_type
|
63
|
-
|
63
|
+
|
64
64
|
if type(enemies_type) == int:
|
65
65
|
enemies = [enemies_type] * n_enemies
|
66
66
|
else:
|
@@ -149,9 +149,14 @@ class Environment(SMAX):
|
|
149
149
|
state = self._push_units_away(state) # type: ignore
|
150
150
|
obs = self.get_obs(state)
|
151
151
|
world_state = self.get_world_state(state)
|
152
|
-
obs["world_state"] = jax.lax.stop_gradient(world_state)
|
152
|
+
# obs["world_state"] = jax.lax.stop_gradient(world_state)
|
153
153
|
return obs, state
|
154
154
|
|
155
|
+
def step_env(self, rng, state: State, action: Array):
|
156
|
+
obs, state, rewards, dones, infos = super().step_env(rng, state, action)
|
157
|
+
# delete world_state from obs
|
158
|
+
obs.pop("world_state")
|
159
|
+
return obs, state, rewards, dones, infos
|
155
160
|
|
156
161
|
def _our_push_units_away(
|
157
162
|
self, pos, unit_types, firmness: float = 1.0
|
@@ -198,7 +203,7 @@ class Environment(SMAX):
|
|
198
203
|
raster = jnp.where(jnp.arange(raster.shape[1]) >= minimum[1], raster.T, 0).T
|
199
204
|
raster = jnp.where(jnp.arange(raster.shape[1]) <= maximum[1], raster.T, 0).T
|
200
205
|
return jnp.any(raster)
|
201
|
-
|
206
|
+
|
202
207
|
def update_position(idx, vec):
|
203
208
|
# Compute the movements slightly strangely.
|
204
209
|
# The velocities below are for diagonal directions
|
@@ -337,7 +342,7 @@ class Environment(SMAX):
|
|
337
342
|
pos = jax.vmap(jnp.where)(clash, pos, new_pos)
|
338
343
|
# avoid going out of bounds
|
339
344
|
pos = jnp.maximum(jnp.minimum(pos, jnp.array([self.map_width, self.map_height])),jnp.zeros((2,)),)
|
340
|
-
|
345
|
+
|
341
346
|
# Multiple enemies can attack the same unit.
|
342
347
|
# We have `(health_diff, attacked_idx)` pairs.
|
343
348
|
# `jax.lax.scatter_add` aggregates these exactly
|
@@ -1,10 +1,10 @@
|
|
1
1
|
parabellum/__init__.py,sha256=-5cWXJkHnfH_CbhTEall8Wak8McAFXZHP1L8Fu7Uo5k,373
|
2
2
|
parabellum/aid.py,sha256=HWST27inTFXcp8b11izJF0U7N7DZnRTIS3n1Qfa-Ko4,106
|
3
|
-
parabellum/env.py,sha256=
|
3
|
+
parabellum/env.py,sha256=H1YGAtUYNJd8OHnZ3sOEXbag5L0WjtJHBGL8ymGPvoE,16898
|
4
4
|
parabellum/gun.py,sha256=nvsJdcZ2Qd6lbPlAgsUiaLhstTi1UdLQ8kOnbCenucY,2618
|
5
5
|
parabellum/map.py,sha256=UwMqwySasX5oLw9v5YMJARAPwvQThLTRW36NpbwvBC8,3564
|
6
6
|
parabellum/run.py,sha256=EO_F7VPwayatpSHrcbSahtinsV4QObhcx0jo-4KZO1E,3472
|
7
7
|
parabellum/vis.py,sha256=q7_OIMjzt-7nBOojVVW7Wiiq9ojsjaltIsH6eOOxPKk,6116
|
8
|
-
parabellum-0.2.
|
9
|
-
parabellum-0.2.
|
10
|
-
parabellum-0.2.
|
8
|
+
parabellum-0.2.26.dist-info/METADATA,sha256=AJwdmHRRPG2MosgufeWqQH7LsDRtTvFFGMh1azey9zA,2671
|
9
|
+
parabellum-0.2.26.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
10
|
+
parabellum-0.2.26.dist-info/RECORD,,
|
File without changes
|