parabellum 0.2.25__py3-none-any.whl → 0.2.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
parabellum/env.py CHANGED
@@ -60,7 +60,7 @@ def make_scenario(place, terrain_raster, unit_starting_sectors, allies_type, n_a
60
60
  else:
61
61
  assert(len(allies_type) == n_allies)
62
62
  allies = allies_type
63
-
63
+
64
64
  if type(enemies_type) == int:
65
65
  enemies = [enemies_type] * n_enemies
66
66
  else:
@@ -149,9 +149,14 @@ class Environment(SMAX):
149
149
  state = self._push_units_away(state) # type: ignore
150
150
  obs = self.get_obs(state)
151
151
  world_state = self.get_world_state(state)
152
- obs["world_state"] = jax.lax.stop_gradient(world_state)
152
+ # obs["world_state"] = jax.lax.stop_gradient(world_state)
153
153
  return obs, state
154
154
 
155
def step_env(self, rng, state: State, action: Array):
    """Advance the environment one step, returning observations without ``world_state``.

    Delegates to the parent class's ``step_env`` and returns its results
    unchanged, except that the ``"world_state"`` entry is dropped from the
    observation dict. Presumably this keeps the step observations
    key-compatible with the reset path, whose ``obs["world_state"]``
    assignment is commented out — confirm against the caller that merges
    reset/step observations.

    Args:
        rng: PRNG key, forwarded unchanged to the parent ``step_env``.
        state: Current environment state, forwarded unchanged.
        action: Actions, forwarded unchanged.

    Returns:
        ``(obs, state, rewards, dones, infos)`` exactly as produced by the
        parent, except ``obs`` is a new dict without the ``"world_state"``
        key.
    """
    obs, state, rewards, dones, infos = super().step_env(rng, state, action)
    # Filter into a fresh dict instead of ``obs.pop("world_state")``.
    # A bare pop raises KeyError if the parent ever stops emitting the key.
    # It also mutates the parent's returned dict in place.
    # Filtering always yields the same key set either way.
    obs = {key: value for key, value in obs.items() if key != "world_state"}
    return obs, state, rewards, dones, infos
155
160
 
156
161
  def _our_push_units_away(
157
162
  self, pos, unit_types, firmness: float = 1.0
@@ -198,7 +203,7 @@ class Environment(SMAX):
198
203
  raster = jnp.where(jnp.arange(raster.shape[1]) >= minimum[1], raster.T, 0).T
199
204
  raster = jnp.where(jnp.arange(raster.shape[1]) <= maximum[1], raster.T, 0).T
200
205
  return jnp.any(raster)
201
-
206
+
202
207
  def update_position(idx, vec):
203
208
  # Compute the movements slightly strangely.
204
209
  # The velocities below are for diagonal directions
@@ -337,7 +342,7 @@ class Environment(SMAX):
337
342
  pos = jax.vmap(jnp.where)(clash, pos, new_pos)
338
343
  # avoid going out of bounds
339
344
  pos = jnp.maximum(jnp.minimum(pos, jnp.array([self.map_width, self.map_height])),jnp.zeros((2,)),)
340
-
345
+
341
346
  # Multiple enemies can attack the same unit.
342
347
  # We have `(health_diff, attacked_idx)` pairs.
343
348
  # `jax.lax.scatter_add` aggregates these exactly
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: parabellum
3
- Version: 0.2.25
3
+ Version: 0.2.26
4
4
  Summary: Parabellum environment for parallel warfare simulation
5
5
  Home-page: https://github.com/syrkis/parabellum
6
6
  License: MIT
@@ -1,10 +1,10 @@
1
1
  parabellum/__init__.py,sha256=-5cWXJkHnfH_CbhTEall8Wak8McAFXZHP1L8Fu7Uo5k,373
2
2
  parabellum/aid.py,sha256=HWST27inTFXcp8b11izJF0U7N7DZnRTIS3n1Qfa-Ko4,106
3
- parabellum/env.py,sha256=H1VwxHj8KqbjBJ8b7NMxkAg3Q0qwVzGulrigc26Tzkc,16663
3
+ parabellum/env.py,sha256=H1YGAtUYNJd8OHnZ3sOEXbag5L0WjtJHBGL8ymGPvoE,16898
4
4
  parabellum/gun.py,sha256=nvsJdcZ2Qd6lbPlAgsUiaLhstTi1UdLQ8kOnbCenucY,2618
5
5
  parabellum/map.py,sha256=UwMqwySasX5oLw9v5YMJARAPwvQThLTRW36NpbwvBC8,3564
6
6
  parabellum/run.py,sha256=EO_F7VPwayatpSHrcbSahtinsV4QObhcx0jo-4KZO1E,3472
7
7
  parabellum/vis.py,sha256=q7_OIMjzt-7nBOojVVW7Wiiq9ojsjaltIsH6eOOxPKk,6116
8
- parabellum-0.2.25.dist-info/METADATA,sha256=UCuRLYhSUxnebs5pRs1FaWK67PUJw0kgsJ2UaaxNM9Q,2671
9
- parabellum-0.2.25.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
10
- parabellum-0.2.25.dist-info/RECORD,,
8
+ parabellum-0.2.26.dist-info/METADATA,sha256=AJwdmHRRPG2MosgufeWqQH7LsDRtTvFFGMh1azey9zA,2671
9
+ parabellum-0.2.26.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
10
+ parabellum-0.2.26.dist-info/RECORD,,