parabellum 0.0.0__py3-none-any.whl → 0.0.73__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
parabellum/types.py ADDED
@@ -0,0 +1,138 @@
1
+ # types.py
2
+ # parabellum types
3
+ # by: Noah Syrkis
4
+
5
+ # imports
6
+ from chex import dataclass
7
+ from jaxtyping import Array, Float32, Int
8
+ import jax.numpy as jnp
9
+ from parabellum.geo import geography_fn
10
+ from dataclasses import field
11
+
12
+
13
@dataclass
class Kind:
    """Combat statistics for one unit type.

    ``Rules`` holds one ``Kind`` per unit type (troop, armor, plane, civil,
    medic) and stacks each field below into a per-type array.
    """

    hp: int  # hit points
    dam: int  # damage per cast; Rules.medic uses a negative value (-10), presumably healing
    speed: int  # movement stat — presumably max distance per step; confirm in env.py
    reach: int  # presumably cast/attack range; svg_fn debug mode draws it as a grey ring
    sight: int  # presumably view range; svg_fn debug mode draws it as a yellow ring
    blast: int  # passed to esch as `blast`; presumably area-of-effect radius
    r: float  # passed to esch as the marker `size` in svg_fn; presumably unit radius
22
+
23
+
24
@dataclass
class Rules:
    """Game-balance table: one ``Kind`` per unit type plus stacked stat arrays.

    After construction every stat (``hp``, ``dam``, ``r``, ``speed``,
    ``reach``, ``sight``, ``blast``) is also exposed as a length-5 array
    indexed by type id, in the order troop, armor, plane, civil, medic.
    """

    troop = Kind(hp=120, dam=15, speed=2, reach=4, sight=4, blast=1, r=1)
    armor = Kind(hp=150, dam=12, speed=1, reach=8, sight=16, blast=3, r=2)
    plane = Kind(hp=80, dam=20, speed=4, reach=16, sight=32, blast=2, r=2)
    civil = Kind(hp=100, dam=0, speed=3, reach=5, sight=10, blast=1, r=2)
    medic = Kind(hp=100, dam=-10, speed=3, reach=5, sight=10, blast=1, r=2)

    def __post_init__(self):
        # A type id is its position in this tuple; Team.types uses the same order.
        kinds = (self.troop, self.armor, self.plane, self.civil, self.medic)
        for stat in ("hp", "dam", "r", "speed", "reach", "sight", "blast"):
            column = tuple(getattr(kind, stat) for kind in kinds)
            setattr(self, stat, jnp.array(column))
48
+
49
+
50
@dataclass
class Team:
    """Number of units of each type on one side."""

    troop: int = 1
    armor: int = 0
    plane: int = 0
    civil: int = 0
    medic: int = 0

    def __post_init__(self):
        counts = (self.troop, self.armor, self.plane, self.civil, self.medic)
        # Total number of units on this team.
        self.length: int = sum(counts)
        # One type id per unit, grouped by type, e.g. troop=2, armor=1 -> [0, 0, 1].
        self.types: Array = jnp.repeat(jnp.arange(len(counts)), jnp.array(counts))
63
+
64
+
65
+ # dataclasses
66
@dataclass
class State:
    """Simulation state for all units at one step."""

    pos: Array  # unit positions (presumably (units, 2) — TODO confirm)
    hp: Array  # hit points per unit; gif_fn treats non-zero as alive
    # target: Array  # NOTE(review): disabled field, kept from an earlier design
71
+
72
+
73
@dataclass
class Obs:
    """What one agent observes about nearby units.

    ``ally`` and ``enemy`` compare every entry's team with ``team[0]``, which
    is presumably the observing unit itself (confirm against env.py).
    Both masks are further restricted by ``mask``.
    """

    hp: Array
    pos: Array
    type: Array
    team: Array
    dist: Array
    mask: Array
    reach: Array
    sight: Array
    speed: Array

    @property
    def ally(self):
        """Entries on the same team as entry 0 that are also in ``mask``."""
        reference_team = self.team[0]
        same_side = self.team == reference_team
        return same_side & self.mask

    @property
    def enemy(self):
        """Entries on a different team from entry 0 that are also in ``mask``."""
        reference_team = self.team[0]
        other_side = self.team != reference_team
        return other_side & self.mask
93
+
94
+
95
@dataclass
class Action:
    """Per-unit actions: a position plus an integer action code.

    ``kind`` codes: 0 = invalid, 1 = move, 2 = cast (bomb, bullet or medicine).
    """

    pos: Array
    kind: Int[Array, "..."]  # 0 = invalid, 1 = move, 2 = cast

    @property
    def invalid(self):
        """Boolean mask of entries whose code is 0."""
        return self.kind == 0

    @property
    def move(self):
        """Boolean mask of entries whose code is 1."""
        return self.kind == 1

    @property
    def cast(self):
        """Boolean mask of entries whose code is 2 (cast bomb, bullet or medicine)."""
        return self.kind == 2
111
+
112
+
113
@dataclass
class Config:  # not frozen: __post_init__ attaches derived attributes
    """Simulation settings plus tables derived from them once at construction.

    Units are ordered blue first, then red. After ``__post_init__`` the
    instance also carries:

    - ``length``: total unit count.
    - ``types``: per-unit type ids.
    - ``teams``: per-unit team ids (0 = blu, 1 = red).
    - ``map``: the result of ``geography_fn(place, size)``.
    - The per-type stat arrays copied from ``rules`` (``hp``, ``dam``,
      ``r``, ``speed``, ``reach``, ``sight``, ``blast``).
    - ``root``: ``sqrt(length)`` truncated to int32.
    """

    steps: int = 123
    place: str = "Palazzo della Civiltà Italiana, Rome, Italy"
    force: float = 0.5
    sims: int = 2
    size: int = 64
    knn: int = 2
    blu: Team = field(default_factory=Team)
    red: Team = field(default_factory=Team)
    rules: Rules = field(default_factory=Rules)

    def __post_init__(self):
        blu, red, rules = self.blu, self.red, self.rules

        # Per-unit roster: blue units first, then red.
        self.length: int = blu.length + red.length
        self.types: Array = jnp.concat((blu.types, red.types))
        self.teams: Array = jnp.repeat(jnp.arange(2), jnp.array((blu.length, red.length)))

        # Map for `place` at `size` resolution, computed once per Config.
        self.map: Array = geography_fn(self.place, self.size)

        # Per-type stat tables, aliased from the rules for direct access.
        self.hp: Array = rules.hp
        self.dam: Array = rules.dam
        self.r: Array = rules.r
        self.speed: Array = rules.speed
        self.reach: Array = rules.reach
        self.sight: Array = rules.sight
        self.blast: Array = rules.blast

        self.root: Array = jnp.int32(jnp.sqrt(self.length))
parabellum/utils.py ADDED
@@ -0,0 +1,59 @@
1
+ # %% utils.py
2
+ # parabellum utils
3
+
4
+
5
+ # Imports
6
+ import esch
7
+ import jax.numpy as jnp
8
+ import numpy as np
9
+ from einops import rearrange, repeat
10
+ from jax import tree
11
+ from PIL import Image
12
+ from parabellum.types import Config
13
+
14
# Twilight palette (same colours as neurocope), as hex strings.
red, blue = "#EA344A", "#2B60F6"
17
+
18
+
19
+ # %% Plotting
20
+ def gif_fn(cfg: Config, seq, scale=4): # animate positions TODO: remove dead units
21
+ pos = seq.pos.astype(int)
22
+ cord = jnp.concat((jnp.arange(pos.shape[0]).repeat(pos.shape[1])[..., None], pos.reshape(-1, 2)), axis=1).T
23
+ idxs = cord[:, seq.hp.flatten().astype(bool) > 0]
24
+ imgs = 1 - np.array(repeat(cfg.map, "... -> a ...", a=len(pos)).at[*idxs].set(1))
25
+ imgs = [Image.fromarray(img).resize(np.array(img.shape[:2]) * scale, Image.NEAREST) for img in imgs * 255] # type: ignore
26
+ imgs[0].save("/Users/nobr/desk/s3/btc2sim/sims.gif", save_all=True, append_images=imgs[1:], duration=10, loop=0)
27
+
28
+
29
def svg_fn(cfg: Config, seq, action, fname, targets=None, fps=2, debug=False):
    """Render a rollout as an animated SVG and save it to ``fname``.

    Draws the map at half intensity as the background. Each (team, unit
    type) group is then animated with its team colour, its type's marker
    size (``rules.r``) and blast (``rules.blast``). Target markers are drawn
    when ``targets`` is given.

    Args:
        cfg: Simulation config (map, per-unit ``teams``/``types``, rules).
        seq: Rollout with a 4-D ``pos``. Axis 0 is used as the number of
            panels (columns), and axis 2 is selected by the per-unit masks.
        action: Pytree of action arrays whose axis 2 is the unit axis.
        fname: Output SVG path.
        targets: Optional target positions, drawn as purple squares in every panel.
        fps: Animation frame rate forwarded to ``esch``.
        debug: Also draw each unit's reach (grey) and sight (yellow) rings.
    """
    n_panels = seq.pos.shape[0]

    # set up and background
    e = esch.Drawing(h=cfg.size, w=cfg.size, row=1, col=n_panels, debug=debug, pad=10)
    esch.grid_fn(repeat(np.array(cfg.map, dtype=float), f"... -> {n_panels} ...") * 0.5, e, shape="square")

    # Loop-invariant: move axis 1 of seq.pos to the end, so the per-unit mask
    # below indexes axis 1 of this array (axis 2 of seq.pos). Convert once.
    positions = np.array(rearrange(seq.pos, "a b c d -> a c d b"), dtype=float)

    # loop through teams (team 1 is red, everything else blue)
    for i in jnp.unique(cfg.teams):
        col = "red" if i == 1 else "blue"

        # loop through types
        for j in jnp.unique(cfg.types):
            mask = (cfg.teams == i) & (cfg.types == j)
            size, blast = float(cfg.rules.r[j]), float(cfg.rules.blast[j])
            subset = positions[:, mask]
            sub_action = tree.map(lambda x: x[:, :, mask], action)
            esch.sims_fn(e, subset, action=sub_action, fps=fps, col=col, stroke=col, size=size, blast=blast)

            if debug:
                sight, reach = float(cfg.rules.sight[j]), float(cfg.rules.reach[j])
                esch.sims_fn(e, subset, action=None, col="none", fps=fps, size=reach, stroke="grey")
                esch.sims_fn(e, subset, action=None, col="none", fps=fps, size=sight, stroke="yellow")

    if targets is not None:
        pos = np.array(repeat(targets, f"... -> {n_panels} ..."))
        arr = np.ones(pos.shape[:-1])
        esch.mesh_fn(e, pos, arr, shape="square", col="purple")

    # save
    e.dwg.saveas(fname)
@@ -0,0 +1,46 @@
1
+ Metadata-Version: 2.4
2
+ Name: parabellum
3
+ Version: 0.0.73
4
+ Summary: Parabellum environment for parallel warfare simulation
5
+ Author-email: Noah Syrkis <desk@syrkis.com>
6
+ Requires-Python: <3.12,>=3.11
7
+ Requires-Dist: brax<0.13,>=0.12.1
8
+ Requires-Dist: cachier<4,>=3.1.2
9
+ Requires-Dist: cartopy<0.24,>=0.23.0
10
+ Requires-Dist: contextily<2,>=1.6.0
11
+ Requires-Dist: distrax<0.2,>=0.1.5
12
+ Requires-Dist: einops<0.9,>=0.8.0
13
+ Requires-Dist: equinox>=0.12.2
14
+ Requires-Dist: evosax<0.2,>=0.1.6
15
+ Requires-Dist: flashbax<0.2,>=0.1.2
16
+ Requires-Dist: flax<0.11,>=0.10.4
17
+ Requires-Dist: folium<0.18,>=0.17.0
18
+ Requires-Dist: geopy<3,>=2.4.1
19
+ Requires-Dist: gymnax<0.0.9,>=0.0.8
20
+ Requires-Dist: ipykernel<7,>=6.29.5
21
+ Requires-Dist: ipython>=8.36.0
22
+ Requires-Dist: jax-tqdm<0.4,>=0.3.1
23
+ Requires-Dist: jax<0.7,>=0.6.0
24
+ Requires-Dist: jaxkd>=0.1.0
25
+ Requires-Dist: jaxtyping<0.3,>=0.2.33
26
+ Requires-Dist: jupyterlab<5,>=4.2.2
27
+ Requires-Dist: navix<0.8,>=0.7.0
28
+ Requires-Dist: notebook>=7.4.2
29
+ Requires-Dist: numpy>=2
30
+ Requires-Dist: omegaconf<3,>=2.3.0
31
+ Requires-Dist: optax<0.3,>=0.2.4
32
+ Requires-Dist: osmnx==2.0.0b0
33
+ Requires-Dist: pandas<3,>=2.2.2
34
+ Requires-Dist: poetry<2,>=1.8.3
35
+ Requires-Dist: rasterio<2,>=1.3.10
36
+ Requires-Dist: stadiamaps<4,>=3.2.1
37
+ Requires-Dist: tensorboard>=2.19.0
38
+ Requires-Dist: tensorflow>=2.19.0
39
+ Requires-Dist: tqdm<5,>=4.66.4
40
+ Requires-Dist: wandb<0.20,>=0.19.7
41
+ Requires-Dist: xprof>=2.20.0
42
+ Description-Content-Type: text/markdown
43
+
44
+ # Parabellum
45
+
46
+ TODO: switch to red and blue team semantics (not enemy and ally)
@@ -0,0 +1,8 @@
1
+ parabellum/__init__.py,sha256=yl3tJXQYNnBgAviJINdHrcALb1177n0CFZ2lGiacbQA,109
2
+ parabellum/env.py,sha256=8KaqUwAmP9FXrINOxGWwZmQJeM6eigeWdjNhlWUNi68,3236
3
+ parabellum/geo.py,sha256=lZ4TQvpfPGW3LV9X-1lfWh3LnwoRCtl_DkN6h26xMng,6922
4
+ parabellum/types.py,sha256=mkBwCA-2_syJAZBeNie9iCG9Cy0SvANMAjCo93XIS1A,3981
5
+ parabellum/utils.py,sha256=wOE2nfcdSvGOBFaGlTEIzS6w451arpT9VSOf_3Zhjwc,2497
6
+ parabellum-0.0.73.dist-info/METADATA,sha256=jHqmHqdrCHM2tIam9bMZ7xwvR1BFLy9bc9qSviKcYAc,1481
7
+ parabellum-0.0.73.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
+ parabellum-0.0.73.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: poetry-core 1.9.0
2
+ Generator: hatchling 1.27.0
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
@@ -1,4 +0,0 @@
1
- from .env import Parabellum, Scenario
2
- from .vis import Visualizer
3
-
4
# Public names re-exported by the package (imported above from .env and .vis).
__all__ = ["Parabellum", "Visualizer", "Scenario"]
@@ -1,296 +0,0 @@
1
- """Parabellum environment based on SMAX"""
2
-
3
- import jax.numpy as jnp
4
- import jax
5
- import numpy as np
6
- from jax import random
7
- from jax import jit
8
- from flax.struct import dataclass
9
- import chex
10
- from jaxmarl.environments.smax.smax_env import State, SMAX
11
- from typing import Tuple, Dict
12
- from functools import partial
13
-
14
-
15
@dataclass
class Scenario:
    """Parabellum scenario: obstacle segments and the unit roster.

    Passed whole to ``SMAX.__init__`` by ``Parabellum``, which also reads the
    obstacle fields directly.
    """

    obstacle_coords: chex.Array  # start point of each obstacle segment
    obstacle_deltas: chex.Array  # offset from start to end of each segment (end = coords + deltas)

    unit_types: chex.Array  # type id per unit (the default scenario has num_allies + num_enemies entries)
    num_allies: int
    num_enemies: int

    # Forwarded to SMAX with the scenario; presumably toggles SMACv2-style
    # random generation — confirm against jaxmarl.
    smacv2_position_generation: bool = False
    smacv2_unit_type_generation: bool = False
-
29
-
30
# Built-in scenarios by name. "default" has two obstacle segments
# (coordinates scaled by 8) and 9 allies vs 10 enemies, all of type 0.
scenarios = {
    "default": Scenario(
        obstacle_coords=jnp.array([[6, 10], [26, 10]]) * 8,
        obstacle_deltas=jnp.array([[0, 12], [0, 1]]) * 8,
        unit_types=jnp.zeros((19,), dtype=jnp.uint8),
        num_allies=9,
        num_enemies=10,
    )
}
40
-
41
-
42
class Parabellum(SMAX):
    """SMAX variant with line-segment obstacles and (WIP) blast damage.

    Overrides SMAX's ``_world_step`` so that both unit movement and the
    unit-separation push are cancelled when they would cross an obstacle
    segment. It also adds an experimental bystander ("blast") damage term.
    """

    def __init__(
        self,
        scenario: Scenario = scenarios["default"],
        unit_type_attack_blasts=jnp.array([0, 0, 0, 0, 0, 0]) + 8,
        **kwargs,
    ):
        """Build the environment.

        Args:
            scenario: Obstacles and unit roster; also forwarded to ``SMAX.__init__``.
            unit_type_attack_blasts: Blast radius per unit type (8 for all six by default).
            **kwargs: Forwarded to ``SMAX.__init__`` (e.g. ``map_width``, ``map_height``).
        """
        super().__init__(scenario=scenario, **kwargs)
        self.unit_type_attack_blasts = unit_type_attack_blasts
        self.obstacle_coords = scenario.obstacle_coords.astype(jnp.float32)
        self.obstacle_deltas = scenario.obstacle_deltas.astype(jnp.float32)
        # NOTE(review): episode length is fixed at 200, set after super().__init__ ran.
        self.max_steps = 200
        # overwrite supers _world_step method


    def _push_units_away(self, state: State, firmness: float = 1.0): # we do it inside the _world_step to allow more obstacles constraints
        """No-op override.

        The push is applied inside ``_world_step`` via ``_our_push_units_away``,
        so that obstacle checks can run after it.
        """
        return state

    def _our_push_units_away(self, pos, unit_types, firmness: float = 1.0): # copy of SMAX._push_units_away but used without state and called inside _world_step to allow more obstacles constraints
        """Push overlapping units apart and return the adjusted positions.

        Args:
            pos: Unit positions, one row per agent.
            unit_types: Type id per agent, used to look up radii.
            firmness: Fraction of the overlap correction to apply.
        """
        # Pairwise displacement between every pair of units.
        delta_matrix = pos[:, None] - pos[None, :]
        # Identity + epsilon keep the diagonal and coincident units away from divide-by-zero.
        dist_matrix = (
            jnp.linalg.norm(delta_matrix, axis=-1)
            + jnp.identity(self.num_agents)
            + 1e-6
        )
        # Sum of the two units' radii: the minimum non-overlapping distance.
        radius_matrix = (
            self.unit_type_radiuses[unit_types][:, None]
            + self.unit_type_radiuses[unit_types][None, :]
        )
        # Positive only for pairs closer than their combined radius.
        overlap_term = jax.nn.relu(radius_matrix / dist_matrix - 1.0)
        unit_positions = (
            pos
            + firmness * jnp.sum(delta_matrix * overlap_term[:, :, None], axis=1) / 2
        )
        return unit_positions

    @partial(jax.jit, static_argnums=(0,)) # replace the _world_step method
    def _world_step( # modified version of JaxMARL's SMAX _world_step
        self,
        key: chex.PRNGKey,
        state: State,
        actions: Tuple[chex.Array, chex.Array],
    ) -> Tuple[Dict[str, chex.Array], State, Dict[str, float], Dict[str, bool], Dict]:
        """Advance unit positions, health and weapon cooldowns by one step.

        Args:
            key: PRNG key, split into one key per agent for cooldown jitter.
            state: Current SMAX state.
            actions: ``(movement, attack)`` arrays, vmapped over agents.

        NOTE(review): the return annotation does not match the code; this
        function returns only the updated ``State``.
        """

        @partial(jax.vmap, in_axes=(None, None, 0, 0))
        def inter_fn(pos, new_pos, obs, obs_end):
            # Segment-intersection test using 2-D cross products, vmapped over
            # obstacles. True where the move segment pos->new_pos intersects
            # obs->obs_end. Touching or collinear counts as intersecting (<= 0).
            d1 = jnp.cross(obs - pos, new_pos - pos)
            d2 = jnp.cross(obs_end - pos, new_pos - pos)
            d3 = jnp.cross(pos - obs, obs_end - obs)
            d4 = jnp.cross(new_pos - obs, obs_end - obs)
            return (d1 * d2 <= 0) & (d3 * d4 <= 0)

        def update_position(idx, vec):
            # Move unit `idx` by `vec` scaled by its type's velocity and the
            # step duration, clamp to the map, and cancel the move if it would
            # cross any obstacle segment.
            # NOTE(review): the inherited SMAX comment about rotating diagonal
            # velocities by 45 degrees does not apply here; `vec` is used
            # as-is (presumably already decoded by SMAX before _world_step).
            pos = state.unit_positions[idx]
            new_pos = (
                pos
                + vec
                * self.unit_type_velocities[state.unit_types[idx]]
                * self.time_per_step
            )
            # avoid going out of bounds
            new_pos = jnp.maximum(
                jnp.minimum(new_pos, jnp.array([self.map_width, self.map_height])),
                jnp.zeros((2,)),
            )

            # avoid going into obstacles: stay put if the move crosses any segment
            obs = self.obstacle_coords
            obs_end = obs + self.obstacle_deltas
            inters = jnp.any(inter_fn(pos, new_pos, obs, obs_end))
            new_pos = jnp.where(inters, pos, new_pos)

            return new_pos

        def bystander_fn(attacked_idx):
            # Intended: mask of units within the blast radius of `attacked_idx`.
            # NOTE(review): `idxs` starts as zeros and is only multiplied, so
            # this always returns all zeros and bystander damage never applies.
            # The radius is also looked up by the attacked unit's type rather
            # than the attacker's; confirm which is intended.
            idxs = jnp.zeros((self.num_agents,))
            idxs *= (
                jnp.linalg.norm(
                    state.unit_positions - state.unit_positions[attacked_idx], axis=-1
                )
                < self.unit_type_attack_blasts[state.unit_types[attacked_idx]]
            )
            return idxs

        def update_agent_health(idx, action, key): # TODO: add attack blasts
            # Returns (health_diff, attacked_idx, cooldown_diff,
            # (bystander_health_diff, bystander_idxs)) for attacker `idx`.
            # For team 1, attack actions are labelled in reverse order,
            # because that is the order in which they are observed.
            attacked_idx = jax.lax.cond(
                idx < self.num_allies,
                lambda: action + self.num_allies - self.num_movement_actions,
                lambda: self.num_allies - 1 - (action - self.num_movement_actions),
            )
            # deal with no-op attack actions (i.e. agents that are moving instead)
            attacked_idx = jax.lax.select(
                action < self.num_movement_actions, idx, attacked_idx
            )

            # An attack lands only if the target is in range, both units are
            # alive, it is not a self-target, and the weapon is off cooldown.
            attack_valid = (
                (
                    jnp.linalg.norm(
                        state.unit_positions[idx] - state.unit_positions[attacked_idx]
                    )
                    < self.unit_type_attack_ranges[state.unit_types[idx]]
                )
                & state.unit_alive[idx]
                & state.unit_alive[attacked_idx]
            )
            attack_valid = attack_valid & (idx != attacked_idx)
            attack_valid = attack_valid & (state.unit_weapon_cooldowns[idx] <= 0.0)
            health_diff = jax.lax.select(
                attack_valid,
                -self.unit_type_attacks[state.unit_types[idx]],
                0.0,
            )
            # design choice based on the pysc2 randomness details.
            # See https://github.com/deepmind/pysc2/blob/master/docs/environment.md#determinism-and-randomness

            # Bystander (blast) damage: same attack value applied to every unit
            # in the blast mask, but only when the attack itself is valid.
            bystander_idxs = bystander_fn(attacked_idx) # TODO: use
            bystander_valid = (
                jnp.where(attack_valid, bystander_idxs, jnp.zeros((self.num_agents,)))
                .astype(jnp.bool_)
                .astype(jnp.float32)
            )
            bystander_health_diff = (
                bystander_valid * -self.unit_type_attacks[state.unit_types[idx]]
            )

            # Cooldown jitter drawn from [-time_per_step, 2 * time_per_step).
            cooldown_deviation = jax.random.uniform(
                key, minval=-self.time_per_step, maxval=2 * self.time_per_step
            )
            cooldown = (
                self.unit_type_weapon_cooldowns[state.unit_types[idx]]
                + cooldown_deviation
            )
            cooldown_diff = jax.lax.select(
                attack_valid,
                # subtract the current cooldown because we are
                # going to add it back. This way we effectively
                # set the new cooldown to `cooldown`
                cooldown - state.unit_weapon_cooldowns[idx],
                -self.time_per_step,
            )
            return (
                health_diff,
                attacked_idx,
                cooldown_diff,
                (bystander_health_diff, bystander_idxs),
            )

        def perform_agent_action(idx, action, key):
            # Apply one agent's (movement, attack) pair.
            movement_action, attack_action = action
            new_pos = update_position(idx, movement_action)
            health_diff, attacked_idxes, cooldown_diff, (bystander) = (
                update_agent_health(idx, attack_action, key)
            )

            return new_pos, (health_diff, attacked_idxes), cooldown_diff, bystander

        keys = jax.random.split(key, num=self.num_agents)
        pos, (health_diff, attacked_idxes), cooldown_diff, bystander = jax.vmap(
            perform_agent_action
        )(jnp.arange(self.num_agents), actions, keys)

        # checked that no unit passed through an obstacles (in update_position);
        # now resolve overlaps between units
        new_pos = self._our_push_units_away(pos, state.unit_types)

        # avoid going into obstacles after being pushed
        obs = self.obstacle_coords
        obs_end = obs + self.obstacle_deltas

        def check_obstacles(pos, new_pos, obs, obs_end):
            # Keep the pre-push position if the push would cross an obstacle.
            inters = jnp.any(inter_fn(pos, new_pos, obs, obs_end))
            return jnp.where(inters, pos, new_pos)

        pos = jax.vmap(check_obstacles, in_axes=(0,0,None,None))(pos, new_pos, obs, obs_end)

        # Multiple enemies can attack the same unit.
        # We have `(health_diff, attacked_idx)` pairs.
        # `jax.lax.scatter_add` aggregates these exactly
        # in the way we want -- duplicate idxes will have their
        # health differences added together. However, it is a
        # super thin wrapper around the XLA scatter operation,
        # which has this bonkers syntax and requires this dnums
        # parameter. The usage here was inferred from a test:
        # https://github.com/google/jax/blob/main/tests/lax_test.py#L2296
        dnums = jax.lax.ScatterDimensionNumbers(
            update_window_dims=(),
            inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,),
        )
        unit_health = jnp.maximum(
            jax.lax.scatter_add(
                state.unit_health,
                jnp.expand_dims(attacked_idxes, 1),
                health_diff,
                dnums,
            ),
            0.0,
        )

        # Subtract bystander health.
        # NOTE(review): `bystander` is (bystander_health_diff, bystander_idxs),
        # as returned by update_agent_health, so this unpacking binds the index
        # mask, not the health diff. The health diff is also non-positive for
        # positive attack values, so `-=` would heal instead of damage. This
        # also runs after the clamp to 0 above. All of this is currently masked
        # because bystander_fn always returns zeros. After the vmap, axis 0
        # indexes the attacker, so axis=0 sums per affected unit.
        _, bystander_health_diff = bystander
        unit_health -= bystander_health_diff.sum(axis=0) # might be axis=1

        unit_weapon_cooldowns = state.unit_weapon_cooldowns + cooldown_diff
        state = state.replace(
            unit_health=unit_health,
            unit_positions=pos,
            unit_weapon_cooldowns=unit_weapon_cooldowns,
        )
        return state
280
-
281
if __name__ == "__main__":
    # Smoke test: roll out 100 steps of uniformly random discrete actions and
    # record (observation, state, actions) after every step.
    env = Parabellum(map_width=256, map_height=256)
    rng, key = random.split(random.PRNGKey(0))
    obs, state = env.reset(key)
    state_seq = []
    for _ in range(100):
        # Independent keys for action sampling and for the env step. Reusing
        # one key for both would correlate their randomness.
        rng, key_act, key_step = random.split(rng, 3)
        key_act = random.split(key_act, len(env.agents))
        actions = {
            agent: jax.random.randint(key_act[i], (), 0, 5)
            for i, agent in enumerate(env.agents)
        }
        # Keep the returned observation; previously every entry held the reset obs.
        obs, state, _, _, _ = env.step(key_step, state, actions)
        state_seq.append((obs, state, actions))
295
-
296
-