multi-agent-coverage 0.1__tar.gz

@@ -0,0 +1,7 @@
1
+ Metadata-Version: 2.4
2
+ Name: multi_agent_coverage
3
+ Version: 0.1
4
+ Summary: High-performance batched multi-agent environment
5
+ Author: Your Name
6
+ Dynamic: author
7
+ Dynamic: summary
@@ -0,0 +1,293 @@
1
+ # Multi-Agent Coverage Environment
2
+
3
+ A high-performance batched multi-agent environment built with C++ (pybind11) and OpenMP for fast parallel simulation of agents exploring a grid world.
4
+
5
+ ## Demo
6
+
7
+ ![Multi-Agent Coverage Demo](demo.gif)
8
+
9
+ ## Features
10
+
11
+ - **High-Performance**: ~11.5k FPS for a single environment, ~134k FPS aggregated across 16 parallel environments
12
+ - **Batched Simulation**: Run multiple independent environments efficiently in parallel
13
+ - **Zero-Copy Memory**: Direct memory sharing between the C++ backend and PyTorch tensors (see the sketch after this list)
14
+ - **Gymnasium Compatible**: Standard gymnasium.vector.VectorEnv-style interface
15
+ - **Gravity-Based Attractions**: Query attraction vectors towards map features for each agent
16
+ - **PyGame Visualization**: Real-time rendering of environment state
17
+
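+ The zero-copy sharing is handled internally by the `BatchedGridEnv` wrapper, but as a rough illustration of the idea, the raw buffer exposed by the low-level `BatchedEnvironment` bindings can be wrapped without a copy along these lines (only `get_memory_view()` and `get_stride()` come from the compiled module; the rest is illustrative):
+
+ ```python
+ import ctypes
+ import numpy as np
+ import torch
+ import multi_agent_coverage as mac
+
+ env = mac.BatchedEnvironment(n_envs=4)
+ ptr, size_bytes = env.get_memory_view()            # (address, size in bytes) of the C++ state buffer
+ floats = (ctypes.c_float * (size_bytes // 4)).from_address(ptr)
+ obs_np = np.ctypeslib.as_array(floats).reshape(4, env.get_stride())
+ obs = torch.from_numpy(obs_np)                     # shares memory with the C++ backend; no copy
+ # The view stays valid only while `env` is alive.
+ ```
+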
18
+ ## Installation
19
+
20
+ ### From Source
21
+
22
+ ```bash
23
+ # Clone repository
24
+ git clone <repository>
25
+ cd craptop
26
+
27
+ # Create virtual environment
28
+ python -m venv .venv
29
+ source .venv/bin/activate # On Windows: .venv\Scripts\activate
30
+
31
+ # Install in development mode
32
+ pip install -e .
33
+ ```
34
+
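+ To verify the extension built and imports correctly, a quick smoke test against the low-level module (the `env_wrapper` classes documented below build on it):
+
+ ```python
+ import multi_agent_coverage as mac
+
+ env = mac.BatchedEnvironment(n_envs=1)
+ print(env.num_envs, env.get_stride())   # expected: 1 15400
+ ```
+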
35
+ ### Requirements
36
+
37
+ - Python 3.10+
38
+ - pybind11
39
+ - NumPy, PyTorch, Gymnasium, and PyGame with a working display (for the Python wrapper and rendering)
40
+ - GCC/Clang with OpenMP support
41
+
42
+ ## API Reference
43
+
44
+ ### `BatchedGridEnv`
45
+
46
+ High-level gymnasium-compatible wrapper around the C++ environment.
47
+
48
+ #### Constructor
49
+
50
+ ```python
51
+ from env_wrapper import BatchedGridEnv, FeatureType
52
+
53
+ env = BatchedGridEnv(
54
+ num_envs=16, # Number of parallel environments
55
+ n_agents=4, # Agents per environment
56
+ map_size=32, # Grid size (32x32)
57
+ device='cpu', # PyTorch device
58
+ render_mode=None # 'human' for rendering, None for headless
59
+ )
60
+ ```
61
+
62
+ #### Methods
63
+
64
+ ##### `reset(seed=None, options=None)`
65
+ Reset all environments and return observations.
66
+
67
+ ```python
68
+ obs, info = env.reset()
69
+ # obs shape: (num_envs, stride); stride = 15400 floats for the default 32×32 map with 4 agents
70
+ ```
71
+
72
+ ##### `step(actions)`
73
+ Execute actions and return observations, rewards, and terminal flags.
74
+
75
+ ```python
76
+ actions = np.random.uniform(-1, 1, (num_envs, n_agents, 2))
77
+ obs, rewards, terminated, truncated, info = env.step(actions)
78
+
79
+ # obs shape: (num_envs, stride)
80
+ # rewards shape: (num_envs, n_agents)
81
+ # terminated, truncated: (num_envs,) bool arrays
82
+ ```
83
+
84
+ ##### `get_gravity_attractions(feature_type, agent_mask=None, pow=2)`
85
+ Compute gravity attraction vectors towards a map feature.
86
+
87
+ ```python
88
+ from env_wrapper import FeatureType
89
+
90
+ # Get attractions towards discovered areas
91
+ gravity = env.get_gravity_attractions(
92
+ feature_type=FeatureType.GLOBAL_DISCOVERED,
93
+ agent_mask=None, # None = all agents
94
+ pow=2 # Inverse-distance power (attraction ~ mass / dist^pow; see the sketch below the feature list)
95
+ )
96
+ # Returns torch.Tensor of shape (num_envs, n_agents, 2) with (gx, gy)
97
+ ```
98
+
99
+ **Feature Types:**
100
+ - `FeatureType.EXPECTED_DANGER` - Expected danger map (global; identical for all agents)
101
+ - `FeatureType.ACTUAL_DANGER` - True danger map (global)
102
+ - `FeatureType.OBSERVED_DANGER` - Per-agent observed danger map
103
+ - `FeatureType.OBS` - Per-agent observation mask (whether cell was seen)
104
+ - `FeatureType.EXPECTED_OBS` - Per-agent expected observation map
105
+ - `FeatureType.GLOBAL_DISCOVERED` - Global discovery map (shared across agents)
+ - `FeatureType.GLOBAL_UNDISCOVERED` - Inverse of the global discovery map (attraction towards unexplored cells)
+ - `FeatureType.OBS_UNDISCOVERED` - Inverse of the per-agent observation mask
+ - `FeatureType.EXPECTED_OBS_UNDISCOVERED` - Inverse of the per-agent expected observation map
+ - `FeatureType.OTHER_AGENTS` - Current positions of the other agents, treated as unit point masses
+ - `FeatureType.OTHER_AGENTS_LAST_KNOWN` - Each agent's last known positions of the other agents
106
+
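+ For reference, each attraction vector is a sum over map cells of `mass / dist^pow` along the unit direction from the agent to the cell. A standalone NumPy sketch of that formula (mirroring the backend's `get_gravity`, not the exact wrapper code):
+
+ ```python
+ import numpy as np
+
+ def gravity_toward(feature_map, agent_y, agent_x, pow=2):
+     """Sum of mass / dist^pow contributions, pointing towards cells with mass."""
+     ys, xs = np.mgrid[0:feature_map.shape[0], 0:feature_map.shape[1]]
+     dy, dx = ys - agent_y, xs - agent_x
+     dist = np.sqrt(dx * dx + dy * dy)
+     valid = (feature_map > 1e-3) & (dist >= 0.1)   # skip empty cells and the agent's own cell
+     force = np.where(valid, feature_map / np.maximum(dist, 0.1) ** (pow + 1), 0.0)
+     return float(np.sum(dx * force)), float(np.sum(dy * force))   # (gx, gy)
+ ```
+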
107
+ **Agent Mask:**
108
+ ```python
109
+ # Only get attractions for first 2 agents
110
+ mask = np.array([True, True, False, False])
111
+ gravity = env.get_gravity_attractions(FeatureType.GLOBAL_DISCOVERED, agent_mask=mask)
112
+ # Masked agents have zero gravity vectors
113
+ ```
114
+
115
+ ##### `render()`
116
+ Render current state to screen (only if `render_mode='human'`).
117
+
118
+ ```python
119
+ env = BatchedGridEnv(num_envs=4, render_mode='human')
120
+ obs, _ = env.reset()
121
+
122
+ while True:
123
+ actions = np.random.uniform(-1, 1, (4, 4, 2))
124
+ obs, r, term, trunc, info = env.step(actions)
125
+ env.render() # Called automatically in step() if render_mode='human'
126
+ ```
127
+
128
+ ##### `close()`
129
+ Clean up resources.
130
+
131
+ ```python
132
+ env.close()
133
+ ```
134
+
135
+ ## Usage Examples
136
+
137
+ ### Basic Loop
138
+
139
+ ```python
140
+ import numpy as np
141
+ from env_wrapper import BatchedGridEnv
142
+
143
+ env = BatchedGridEnv(num_envs=8, n_agents=4)
144
+ obs, _ = env.reset()
145
+
146
+ for step in range(1000):
147
+ # Random policy
148
+ actions = np.random.uniform(-1, 1, (8, 4, 2))
149
+
150
+ obs, rewards, terminated, truncated, info = env.step(actions)
151
+
152
+ # Access observations or rewards
153
+ print(f"Step {step}, Rewards: {rewards}")
154
+
155
+ env.close()
156
+ ```
157
+
158
+ ### With Rendering
159
+
160
+ ```python
161
+ env = BatchedGridEnv(num_envs=1, render_mode='human')
162
+ obs, _ = env.reset()
163
+
164
+ try:
165
+ for _ in range(500):
166
+ actions = np.random.uniform(-1, 1, (1, 4, 2))
167
+ env.step(actions)
168
+ # Render is called automatically
169
+ except KeyboardInterrupt:
170
+ env.close()
171
+ ```
172
+
173
+ ### Gravity-Based Navigation
174
+
175
+ ```python
176
+ from env_wrapper import BatchedGridEnv, FeatureType
177
+
178
+ env = BatchedGridEnv(num_envs=4, n_agents=4)
179
+ obs, _ = env.reset()
180
+
181
+ for step in range(100):
182
+ # Get gravity towards global discovered areas
183
+ gravity = env.get_gravity_attractions(FeatureType.GLOBAL_DISCOVERED, pow=2)
184
+
185
+ # Move toward high gravity (explored areas)
186
+ actions = gravity.numpy() * 0.5
187
+ actions = np.clip(actions, -1, 1)
188
+
189
+ obs, rewards, _, _, _ = env.step(actions)
190
+ print(f"Step {step}, Discovery rewards: {rewards[0]}")
191
+
192
+ env.close()
193
+ ```
194
+
195
+ ### Observation Space Layout
196
+
197
+ The observation is a flattened array with the following structure:
198
+
199
+ ```
200
+ Index Range | Content | Shape
201
+ 0 - 1024 | Expected Danger | (32, 32)
202
+ 1024 - 2048 | Actual Danger | (32, 32)
203
+ 2048 - 6144 | Observed Danger | (4, 32, 32)
204
+ 6144 - 10240 | Observation Mask | (4, 32, 32)
205
+ 10240 - 10248 | Agent Locations | (4, 2) - [y, x] per agent
206
+ 10248 - 14344 | Expected Obs | (4, 32, 32)
207
+ 14344 - 14376 | Last Agent Locations | (4, 4, 2) - last known [y, x] of every agent, per agent
208
+ 14376 - 15400 | Global Discovered | (32, 32)
209
+ ```
210
+
211
+ Access slices:
212
+ ```python
213
+ import numpy as np
214
+
215
+ obs = obs[0].numpy() # Get first env's observation
216
+ fms = 32 * 32 # FLAT_MAP_SIZE
217
+
218
+ expected_danger = obs[0:fms].reshape(32, 32)
219
+ actual_danger = obs[fms:2*fms].reshape(32, 32)
220
+ agent_locs_offset = (2 + 2*4) * fms
221
+ agent_locations = obs[agent_locs_offset:agent_locs_offset+8].reshape(4, 2)
222
+ discovered = obs[-fms:].reshape(32, 32)
223
+ ```
224
+
225
+ ## Recording Demonstrations
226
+
227
+ Generate an animated GIF of the environment:
228
+
229
+ ```bash
230
+ python gif.py
231
+ ```
232
+
233
+ This creates `demo.gif` showing agents exploring a 32x32 grid. The GIF shows:
234
+ - Green cells: safe, explored areas
235
+ - Red cells: danger, explored areas
236
+ - Dark gray: unexplored cells
237
+ - Light blue dots: agents
238
+
239
+ ## Performance
240
+
241
+ Benchmark results (on a typical Linux machine with OpenMP enabled; a rough reproduction sketch follows the table):
242
+
243
+ | Config | FPS |
244
+ |--------|-----|
245
+ | 1 env, 10k frames | ~11,500 |
246
+ | 16 envs, 10k frames | ~134,000 (total across all envs) |
247
+
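+ These figures vary with hardware and thread count. A minimal timing loop along these lines reproduces a rough number (illustrative script, not necessarily the exact benchmark used):
+
+ ```python
+ import time
+ import numpy as np
+ from env_wrapper import BatchedGridEnv
+
+ num_envs, frames = 16, 10_000
+ env = BatchedGridEnv(num_envs=num_envs, n_agents=4)
+ env.reset()
+ start = time.perf_counter()
+ for _ in range(frames):
+     env.step(np.random.uniform(-1, 1, (num_envs, 4, 2)))
+ elapsed = time.perf_counter() - start
+ print(f"{num_envs * frames / elapsed:,.0f} FPS (total across envs)")
+ env.close()
+ ```
+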
248
+ ## Environment Details
249
+
250
+ ### State
251
+ - **Map Size**: 32×32 fixed grid
252
+ - **Agents per Env**: 4 fixed
253
+ - **Agent Speed**: 0.5 cells/step (reduced in danger zones)
254
+ - **View Range**: 3 cells (7×7 view window)
255
+
256
+ ### Rewards
257
+ Each newly discovered cell yields a total of +1.0 reward, split equally among the agents that see it on that step (e.g. two agents uncovering the same cell receive +0.5 each). Total discoverable: 1024 cells.
258
+
259
+ ### Dynamics
260
+ - Action vectors are normalized to a unit direction; agent positions are clamped to the map bounds
261
+ - Movement speed is scaled by `1 - 0.8 * danger_level` of the agent's current cell (see the sketch after this list)
262
+ - Agents cannot see beyond their 7×7 view window
263
+
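+ Putting those rules together, the per-step position update for a single agent looks roughly like this (a sketch mirroring the C++ backend, with the fixed 32-cell map and 0.5 base speed):
+
+ ```python
+ import numpy as np
+
+ MAP_SIZE, SPEED, DANGER_PENALTY = 32, 0.5, 0.8
+
+ def step_position(pos_yx, action_yx, actual_danger):
+     """pos_yx: (y, x) floats; action_yx: raw action; actual_danger: (32, 32) danger map."""
+     pos = np.asarray(pos_yx, dtype=float)
+     a = np.asarray(action_yx, dtype=float)
+     norm = np.linalg.norm(a)
+     if norm > 1e-4:
+         a = a / norm                               # only the direction of the action matters
+     cy, cx = np.clip(pos.astype(int), 0, MAP_SIZE - 1)
+     speed = SPEED * (1.0 - DANGER_PENALTY * actual_danger[cy, cx])
+     return np.clip(pos + a * speed, 0.0, MAP_SIZE - 0.01)
+ ```
+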
264
+ ## Building from Source
265
+
266
+ The extension requires a C++ compiler with OpenMP:
267
+
268
+ ```bash
269
+ # Install build dependencies
270
+ pip install pybind11 setuptools build
271
+
272
+ # Build in-place for testing
273
+ python setup.py build_ext --inplace
274
+
275
+ # Or use modern build system
276
+ python -m build
277
+ ```
278
+
279
+ ## Publishing to PyPI
280
+
281
+ ```bash
282
+ # Local build and publish
283
+ export PYPI_API_TOKEN="your-token-here"
284
+ ./build_and_publish.sh
285
+
286
+ # Or via GitHub Actions (requires PYPI_API_TOKEN secret):
287
+ git tag v0.1.0
288
+ git push origin v0.1.0
289
+ ```
290
+
291
+ ## License
292
+
293
+ MIT
@@ -0,0 +1,7 @@
1
+ Metadata-Version: 2.4
2
+ Name: multi_agent_coverage
3
+ Version: 0.1
4
+ Summary: High-performance batched multi-agent environment
5
+ Author: Your Name
6
+ Dynamic: author
7
+ Dynamic: summary
@@ -0,0 +1,9 @@
1
+ README.md
2
+ pyproject.toml
3
+ setup.py
4
+ multi_agent_coverage.egg-info/PKG-INFO
5
+ multi_agent_coverage.egg-info/SOURCES.txt
6
+ multi_agent_coverage.egg-info/dependency_links.txt
7
+ multi_agent_coverage.egg-info/not-zip-safe
8
+ multi_agent_coverage.egg-info/top_level.txt
9
+ src/batched_env.cpp
@@ -0,0 +1 @@
1
+ multi_agent_coverage
@@ -0,0 +1,8 @@
1
+ [build-system]
2
+ requires = [
3
+ "setuptools>=61.0",
4
+ "wheel",
5
+ "pybind11>=2.6",
6
+ "build"
7
+ ]
8
+ build-backend = "setuptools.build_meta"
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,32 @@
1
+ import sys
2
+ import setuptools
3
+ from pybind11.setup_helpers import Pybind11Extension, build_ext
4
+
5
+ # OpenMP flags depend on the compiler
6
+ c_args = []
7
+ l_args = []
8
+
9
+ if sys.platform == "win32":
10
+ c_args = ['/openmp', '/O2']
11
+ else:
12
+ c_args = ['-fopenmp', '-O3', '-march=native']
13
+ l_args = ['-fopenmp']
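+ # Note: -march=native ties the extension to the CPU it was built on, and Apple clang ships
+ # without OpenMP, so macOS builds likely need libomp (e.g. from Homebrew) or a GCC toolchain.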
14
+
15
+ ext_modules = [
16
+ Pybind11Extension(
17
+ "multi_agent_coverage",
18
+ ["src/batched_env.cpp"],
19
+ extra_compile_args=c_args,
20
+ extra_link_args=l_args,
21
+ ),
22
+ ]
23
+
24
+ setuptools.setup(
25
+ name="multi_agent_coverage",
26
+ version="0.1",
27
+ author="Your Name",
28
+ description="High-performance batched multi-agent environment",
29
+ ext_modules=ext_modules,
30
+ cmdclass={"build_ext": build_ext},
31
+ zip_safe=False,
32
+ )
@@ -0,0 +1,473 @@
1
+ #include <pybind11/pybind11.h>
2
+ #include <pybind11/numpy.h>
3
+ #include <vector>
4
+ #include <cmath>
5
+ #include <algorithm>
6
+ #include <cstring>
7
+ #include <iostream>
8
+ #include <random>
9
+ #include <omp.h>
10
+
11
+ namespace py = pybind11;
12
+
13
+ #define MAP_SIZE 32
14
+ #define N_AGENTS 4
15
+ #define VIEW_RANGE 3
16
+ #define SPEED 0.5f
17
+ #define DANGER_PENALTY_FACTOR 0.8f
18
+
19
+ // Feature type enum
20
+ enum FeatureType {
21
+ EXPECTED_DANGER_FEATURE = 0,
22
+ ACTUAL_DANGER_FEATURE = 1,
23
+ OBSERVED_DANGER_FEATURE = 2,
24
+ OBS_FEATURE = 3,
25
+ EXPECTED_OBS_FEATURE = 4,
26
+ GLOBAL_DISCOVERED_FEATURE = 5,
27
+ OTHER_AGENTS_FEATURE = 6,
28
+ OTHER_AGENTS_LAST_KNOWN_FEATURE = 7,
29
+ GLOBAL_UNDISCOVERED_FEATURE = 8,
30
+ OBS_UNDISCOVERED_FEATURE = 9,
31
+ EXPECTED_OBS_UNDISCOVERED_FEATURE = 10
32
+ };
33
+
34
+ constexpr int FLAT_MAP_SIZE = (MAP_SIZE * MAP_SIZE);
35
+
36
+ // Total floats per environment
37
+ constexpr int ENV_STRIDE =
38
+ FLAT_MAP_SIZE + // expected_danger (0)
39
+ FLAT_MAP_SIZE + // actual_danger (1024)
40
+ (N_AGENTS * FLAT_MAP_SIZE) + // observed_danger (2048)
41
+ (N_AGENTS * FLAT_MAP_SIZE) + // obs (mask) (6144)
42
+ (N_AGENTS * 2) + // agent_locations (10240)
43
+ (N_AGENTS * FLAT_MAP_SIZE) + // expected_obs (10248)
44
+ (N_AGENTS * 2 * N_AGENTS) + // last_agent_locations (14344)
45
+ FLAT_MAP_SIZE; // global_discovered (14376)
46
+ // Total: 15400 floats per env
47
+
48
+ struct GameStateView {
49
+ float* expected_danger;
50
+ float* actual_danger;
51
+ float* observed_danger;
52
+ float* obs;
53
+ float* agent_locations;
54
+ float* expected_obs;
55
+ float* last_agent_locations;
56
+ float* global_discovered;
57
+ };
58
+
59
+ // Helper function to compute gravity from other agent positions
60
+ // Treats each agent as a point mass with mass=1.0
61
+ void get_gravity_from_agents(float* all_agent_locations, int current_agent_idx, int pow,
62
+ float agent_x, float agent_y, float& out_gx, float& out_gy) {
63
+ out_gx = 0.0f;
64
+ out_gy = 0.0f;
65
+
66
+ for (int j = 0; j < N_AGENTS; ++j) {
67
+ if (j == current_agent_idx) continue; // Skip self
68
+
69
+ float other_x = all_agent_locations[j * 2 + 1];
70
+ float other_y = all_agent_locations[j * 2];
71
+
72
+ // Vector from current agent TO other agent
73
+ float dx = other_x - agent_x;
74
+ float dy = other_y - agent_y;
75
+
76
+ float dist_sq = dx * dx + dy * dy;
77
+ float dist = std::sqrt(dist_sq);
78
+
79
+ if (dist < 0.1f) continue; // Min distance clamp
80
+
81
+ // Force magnitude = 1.0 / dist^pow (treating other agent as mass=1.0)
82
+ float denom;
83
+ if (pow == 2) denom = dist * dist_sq; // dist^3
84
+ else if (pow == 1) denom = dist_sq; // dist^2
85
+ else denom = std::pow(dist, pow + 1);
86
+
87
+ float force = 1.0f / denom;
88
+
89
+ out_gx += dx * force;
90
+ out_gy += dy * force;
91
+ }
92
+ }
93
+
94
+ // Returns vector: Sum(Mass / dist^pow * direction_vector)
95
+ // returns dx, dy
96
+ // If invert=true, uses (1.0 - map[i]) as mass (for undiscovered features)
97
+ void get_gravity(float* map, int pow, float agent_x, float agent_y, float& out_gx, float& out_gy, bool invert = false) {
98
+ out_gx = 0.0f;
99
+ out_gy = 0.0f;
100
+
101
+ for (int y = 0; y < MAP_SIZE; ++y) {
102
+ for (int x = 0; x < MAP_SIZE; ++x) {
103
+ float mass = invert ? (1.0f - map[y * MAP_SIZE + x]) : map[y * MAP_SIZE + x];
104
+ if (mass <= 0.001f) continue; // Optimization: Ignore empty cells
105
+
106
+ // Vector from Agent TO Cell
107
+ float dx = (float)x - agent_x;
108
+ float dy = (float)y - agent_y;
109
+
110
+ float dist_sq = dx * dx + dy * dy;
111
+ float dist = std::sqrt(dist_sq);
112
+
113
+ if (dist < 0.1f) continue; // Min distance clamp
114
+
115
+ // Force magnitude = Mass / dist^pow
116
+ // Since we need to multiply by normalized direction (dx/dist),
117
+ // We actually divide by dist^(pow+1)
118
+
119
+ float denom;
120
+ if (pow == 2) denom = dist * dist_sq; // dist^3
121
+ else if (pow == 1) denom = dist_sq; // dist^2
122
+ else denom = std::pow(dist, pow + 1);
123
+
124
+ float force = mass / denom;
125
+
126
+ out_gx += dx * force;
127
+ out_gy += dy * force;
128
+ }
129
+ }
130
+ }
131
+
132
+ class BatchedEnvironment {
133
+ public:
134
+ int num_envs;
135
+ int seed;
136
+ std::vector<std::mt19937> rngs;
137
+ std::vector<float> data;
138
+
139
+ BatchedEnvironment(int n_envs, int sim_seed) : num_envs(n_envs), seed(sim_seed) {
140
+ data.resize(num_envs * ENV_STRIDE);
141
+ rngs.resize(num_envs);
142
+ for(int i=0; i<num_envs; ++i) {
143
+ rngs[i].seed(seed + i);
144
+ }
145
+ reset();
146
+ }
147
+
148
+ void bind_state(GameStateView& s, int env_idx) {
149
+ float* ptr = data.data() + (env_idx * ENV_STRIDE);
150
+ s.expected_danger = ptr; ptr += FLAT_MAP_SIZE;
151
+ s.actual_danger = ptr; ptr += FLAT_MAP_SIZE;
152
+ s.observed_danger = ptr; ptr += (N_AGENTS * FLAT_MAP_SIZE);
153
+ s.obs = ptr; ptr += (N_AGENTS * FLAT_MAP_SIZE);
154
+ s.agent_locations = ptr; ptr += (N_AGENTS * 2);
155
+ s.expected_obs = ptr; ptr += (N_AGENTS * FLAT_MAP_SIZE);
156
+ s.last_agent_locations = ptr; ptr += (N_AGENTS * 2 * N_AGENTS);
157
+ s.global_discovered = ptr;
158
+ }
159
+
160
+ void reset() {
161
+ // Parallel reset
162
+ #pragma omp parallel for
163
+ for (int e = 0; e < num_envs; ++e) {
164
+ GameStateView s;
165
+ bind_state(s, e);
166
+
167
+ // Zero out memory for this env
168
+ std::memset(data.data() + (e * ENV_STRIDE), 0, ENV_STRIDE * sizeof(float));
169
+
170
+ // Procedural Map Gen (Thread-safe RNG is tricky, using simple math here)
171
+ for (int i = 0; i < FLAT_MAP_SIZE; ++i) {
172
+ int y = i / MAP_SIZE;
173
+ int x = i % MAP_SIZE;
174
+ // Deterministic pseudo-random based on env index
175
+ // Changed to [-1, 1] range
176
+ float val = (std::sin(x * 0.3f + e) + std::cos(y * 0.3f + e*2)) / 2.0f;
177
+ s.actual_danger[i] = std::fmin(1.0f, std::fmax(-1.0f, val));
178
+ s.expected_danger[i] = 0.0f; // Midpoint of [-1, 1]
179
+ }
180
+
181
+ for (int i = 0; i < N_AGENTS; ++i) {
182
+ s.agent_locations[i * 2] = MAP_SIZE / 2.0f;
183
+ s.agent_locations[i * 2 + 1] = MAP_SIZE / 2.0f;
184
+ }
185
+
186
+ // Initialize observed_danger with expected_danger for all agents
187
+ for (int i = 0; i < N_AGENTS; ++i) {
188
+ for (int j = 0; j < FLAT_MAP_SIZE; ++j) {
189
+ s.observed_danger[i * FLAT_MAP_SIZE + j] = s.expected_danger[j];
190
+ }
191
+ }
192
+
193
+ update_obs(s);
194
+ }
195
+ }
196
+
197
+ py::array_t<float> step(py::array_t<float> actions_array, float communication_prob = -1.0f) {
198
+ auto r = actions_array.unchecked<2>();
199
+ py::array_t<float> rewards_array({num_envs, N_AGENTS});
200
+ auto rewards_ptr = rewards_array.mutable_unchecked<2>();
201
+
202
+ // Parallel Step
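+ // Safe to run in parallel: each iteration touches only its own env's data slice and its own per-env RNG.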
203
+ #pragma omp parallel for
204
+ for (int e = 0; e < num_envs; ++e) {
205
+ GameStateView s;
206
+ bind_state(s, e);
207
+
208
+ update_locations(s, r, e);
209
+ update_obs(s);
210
+ update_last_location(s, e, communication_prob);
211
+ calc_rewards(s, rewards_ptr, e);
212
+ }
213
+
214
+ return rewards_array;
215
+ }
216
+
217
+ // Return (memory_ptr, size_bytes) for Python ctypes/torch
218
+ std::pair<size_t, size_t> get_memory_view() {
219
+ return { (size_t)data.data(), data.size() * sizeof(float) };
220
+ }
221
+
222
+ int get_stride() { return ENV_STRIDE; }
223
+ int get_flat_map_size() { return FLAT_MAP_SIZE; }
224
+
225
+ py::array_t<float> get_gravity_attractions(
226
+ py::object agent_mask_obj,
227
+ int feature_type,
228
+ int pow = 2,
229
+ bool normalize = false)
230
+ {
231
+ // Parse agent mask (None means all agents)
232
+ std::vector<bool> agent_mask(N_AGENTS, true);
233
+ if (!agent_mask_obj.is_none()) {
234
+ auto mask = agent_mask_obj.cast<py::array_t<bool>>();
235
+ auto mask_ptr = mask.data();
236
+ for (int i = 0; i < N_AGENTS; ++i) {
237
+ agent_mask[i] = mask_ptr[i];
238
+ }
239
+ }
240
+
241
+ // Allocate output: (num_envs, N_AGENTS, 2)
242
+ py::array_t<float> output_array({num_envs, N_AGENTS, 2});
243
+ auto output_ptr = output_array.mutable_unchecked<3>();
244
+
245
+ #pragma omp parallel for
246
+ for (int e = 0; e < num_envs; ++e) {
247
+ GameStateView s;
248
+ bind_state(s, e);
249
+
250
+ for (int i = 0; i < N_AGENTS; ++i) {
251
+ if (!agent_mask[i]) {
252
+ output_ptr(e, i, 0) = 0.0f;
253
+ output_ptr(e, i, 1) = 0.0f;
254
+ continue;
255
+ }
256
+
257
+ float agent_x = s.agent_locations[i * 2 + 1];
258
+ float agent_y = s.agent_locations[i * 2];
259
+ float gx, gy;
260
+
261
+ float* feature_map = nullptr;
262
+ bool invert = false;
263
+
264
+ if (feature_type == EXPECTED_DANGER_FEATURE) {
265
+ feature_map = s.expected_danger;
266
+ } else if (feature_type == ACTUAL_DANGER_FEATURE) {
267
+ feature_map = s.actual_danger;
268
+ } else if (feature_type == OBSERVED_DANGER_FEATURE) {
269
+ feature_map = s.observed_danger + (i * FLAT_MAP_SIZE);
270
+ } else if (feature_type == OBS_FEATURE) {
271
+ feature_map = s.obs + (i * FLAT_MAP_SIZE);
272
+ } else if (feature_type == EXPECTED_OBS_FEATURE) {
273
+ feature_map = s.expected_obs + (i * FLAT_MAP_SIZE);
274
+ } else if (feature_type == GLOBAL_DISCOVERED_FEATURE) {
275
+ feature_map = s.global_discovered;
276
+ } else if (feature_type == GLOBAL_UNDISCOVERED_FEATURE) {
277
+ feature_map = s.global_discovered;
278
+ invert = true; // Use (1.0 - discovered)
279
+ } else if (feature_type == OBS_UNDISCOVERED_FEATURE) {
280
+ feature_map = s.obs + (i * FLAT_MAP_SIZE);
281
+ invert = true; // Use (1.0 - obs)
282
+ } else if (feature_type == EXPECTED_OBS_UNDISCOVERED_FEATURE) {
283
+ feature_map = s.expected_obs + (i * FLAT_MAP_SIZE);
284
+ invert = true; // Use (1.0 - expected_obs)
285
+ } else if (feature_type == OTHER_AGENTS_FEATURE) {
286
+ // Use actual current locations of other agents
287
+ get_gravity_from_agents(s.agent_locations, i, pow, agent_x, agent_y, gx, gy);
288
+ feature_map = nullptr; // Signal that we already computed gravity
289
+ } else if (feature_type == OTHER_AGENTS_LAST_KNOWN_FEATURE) {
290
+ // Use this agent's last known locations of other agents
291
+ float* agent_i_last_known = s.last_agent_locations + (i * 2 * N_AGENTS);
292
+ get_gravity_from_agents(agent_i_last_known, i, pow, agent_x, agent_y, gx, gy);
293
+ feature_map = nullptr; // Signal that we already computed gravity
294
+ }
295
+
296
+ if (feature_map) {
297
+ get_gravity(feature_map, pow, agent_x, agent_y, gx, gy, invert);
298
+ } else if (feature_type != OTHER_AGENTS_FEATURE &&
299
+ feature_type != OTHER_AGENTS_LAST_KNOWN_FEATURE) {
300
+ gx = gy = 0.0f;
301
+ }
302
+
303
+ // Normalize to max norm 1.0 if requested
304
+ if (normalize) {
305
+ float mag = std::sqrt(gx * gx + gy * gy);
306
+ if (mag > 1.0f) {
307
+ gx /= mag;
308
+ gy /= mag;
309
+ }
310
+ }
311
+
312
+ output_ptr(e, i, 0) = gy; // dy
313
+ output_ptr(e, i, 1) = gx; // dx
314
+ }
315
+ }
316
+
317
+ return output_array;
318
+ }
319
+
320
+ private:
321
+ void update_locations(GameStateView& s, const py::detail::unchecked_reference<float, 2>& actions, int env_idx) {
322
+ for (int i = 0; i < N_AGENTS; ++i) {
323
+ float dy = actions(env_idx, i * 2);
324
+ float dx = actions(env_idx, i * 2 + 1);
325
+ float len = std::sqrt(dy * dy + dx * dx);
326
+ if (len > 0.0001f) { dy /= len; dx /= len; }
327
+
328
+ int cy = (int)s.agent_locations[i * 2];
329
+ int cx = (int)s.agent_locations[i * 2 + 1];
330
+
331
+ cy = std::max(0, std::min(MAP_SIZE - 1, cy));
332
+ cx = std::max(0, std::min(MAP_SIZE - 1, cx));
333
+
334
+ float danger = s.actual_danger[cy * MAP_SIZE + cx];
335
+ float effective_speed = SPEED * (1.0f - (danger * DANGER_PENALTY_FACTOR));
336
+
337
+ float ny = s.agent_locations[i * 2] + dy * effective_speed;
338
+ float nx = s.agent_locations[i * 2 + 1] + dx * effective_speed;
339
+
340
+ s.agent_locations[i * 2] = std::fmax(0.0f, std::fmin((float)MAP_SIZE - 0.01f, ny));
341
+ s.agent_locations[i * 2 + 1] = std::fmax(0.0f, std::fmin((float)MAP_SIZE - 0.01f, nx));
342
+ }
343
+ }
344
+
345
+ void update_obs(GameStateView& s) {
346
+ for (int i = 0; i < N_AGENTS; ++i) {
347
+ int yc = (int)s.agent_locations[i * 2];
348
+ int xc = (int)s.agent_locations[i * 2 + 1];
349
+
350
+ int y_s = std::max(0, yc - VIEW_RANGE);
351
+ int y_e = std::min(MAP_SIZE, yc + VIEW_RANGE + 1);
352
+ int x_s = std::max(0, xc - VIEW_RANGE);
353
+ int x_e = std::min(MAP_SIZE, xc + VIEW_RANGE + 1);
354
+
355
+ for (int ly = y_s; ly < y_e; ++ly) {
356
+ for (int lx = x_s; lx < x_e; ++lx) {
357
+ int idx = ly * MAP_SIZE + lx;
358
+ int agent_idx = i * FLAT_MAP_SIZE + idx;
359
+ s.obs[agent_idx] = 1.0f;
360
+ s.observed_danger[agent_idx] = s.actual_danger[idx];
361
+ }
362
+ }
363
+ }
364
+ }
365
+
366
+ void update_last_location(GameStateView& s, int env_idx, float p) {
367
+ std::uniform_real_distribution<float> dist(0.0f, 1.0f);
368
+ bool try_radio = (p > 0.0f && p <= 1.0f);
369
+
370
+ // For each agent i (viewer/row agent), check if other agents j (target/col agent)
371
+ // are within view range, and update their last known location if so
372
+ for (int i = 0; i < N_AGENTS; ++i) {
373
+ int viewer_y = (int)s.agent_locations[i * 2];
374
+ int viewer_x = (int)s.agent_locations[i * 2 + 1];
375
+
376
+ for (int j = 0; j < N_AGENTS; ++j) {
377
+ // If it's myself, always update (agents know where they are)
378
+ if (i == j) {
379
+ s.last_agent_locations[i * (2 * N_AGENTS) + j * 2] = s.agent_locations[j * 2];
380
+ s.last_agent_locations[i * (2 * N_AGENTS) + j * 2 + 1] = s.agent_locations[j * 2 + 1];
381
+ continue;
382
+ }
383
+
384
+ int target_y = (int)s.agent_locations[j * 2];
385
+ int target_x = (int)s.agent_locations[j * 2 + 1];
386
+
387
+ bool updated = false;
388
+
389
+ // Check physical view range (Visual)
390
+ if (std::abs(viewer_y - target_y) <= VIEW_RANGE &&
391
+ std::abs(viewer_x - target_x) <= VIEW_RANGE) {
392
+ updated = true;
393
+ }
394
+
395
+ // Check radio communication (Probabilistic)
396
+ // "agents should have a probability 'p' of sharing their location to another agent"
397
+ if (!updated && try_radio) {
398
+ // Agent j shares location with agent i with probability p
399
+ if (dist(rngs[env_idx]) < p) {
400
+ updated = true;
401
+ }
402
+ }
403
+
404
+ if (updated) {
405
+ // Update agent i's knowledge of agent j's location
406
+ s.last_agent_locations[i * (2 * N_AGENTS) + j * 2] = s.agent_locations[j * 2];
407
+ s.last_agent_locations[i * (2 * N_AGENTS) + j * 2 + 1] = s.agent_locations[j * 2 + 1];
408
+ }
409
+ }
410
+ }
411
+ }
412
+
413
+ void calc_rewards(GameStateView& s, py::detail::unchecked_mutable_reference<float, 2>& rewards, int env_idx) {
414
+ for(int i=0; i<N_AGENTS; ++i) rewards(env_idx, i) = 0.0f;
415
+
416
+ for (int y = 0; y < MAP_SIZE; ++y) {
417
+ for (int x = 0; x < MAP_SIZE; ++x) {
418
+ int idx = y * MAP_SIZE + x;
419
+ if (s.global_discovered[idx] > 0.5f) continue;
420
+
421
+ int seeing_count = 0;
422
+ bool seen_by[N_AGENTS] = {false}; // Fixed size array initialization
423
+
424
+ for (int i = 0; i < N_AGENTS; ++i) {
425
+ int ay = (int)s.agent_locations[i * 2];
426
+ int ax = (int)s.agent_locations[i * 2 + 1];
427
+ if (std::abs(ay - y) <= VIEW_RANGE && std::abs(ax - x) <= VIEW_RANGE) {
428
+ seen_by[i] = true;
429
+ seeing_count++;
430
+ }
431
+ }
432
+
433
+ if (seeing_count > 0) {
434
+ s.global_discovered[idx] = 1.0f;
435
+ float share = 1.0f / (float)seeing_count;
436
+ for (int i = 0; i < N_AGENTS; ++i) {
437
+ if (seen_by[i]) rewards(env_idx, i) += share;
438
+ }
439
+ }
440
+ }
441
+ }
442
+ }
443
+ };
444
+
445
+ PYBIND11_MODULE(multi_agent_coverage, m) {
446
+ // Feature type enum
447
+ py::enum_<FeatureType>(m, "FeatureType")
448
+ .value("EXPECTED_DANGER", EXPECTED_DANGER_FEATURE)
449
+ .value("ACTUAL_DANGER", ACTUAL_DANGER_FEATURE)
450
+ .value("OBSERVED_DANGER", OBSERVED_DANGER_FEATURE)
451
+ .value("OBS", OBS_FEATURE)
452
+ .value("EXPECTED_OBS", EXPECTED_OBS_FEATURE)
453
+ .value("GLOBAL_DISCOVERED", GLOBAL_DISCOVERED_FEATURE)
454
+ .value("OTHER_AGENTS", OTHER_AGENTS_FEATURE)
455
+ .value("OTHER_AGENTS_LAST_KNOWN", OTHER_AGENTS_LAST_KNOWN_FEATURE)
456
+ .value("GLOBAL_UNDISCOVERED", GLOBAL_UNDISCOVERED_FEATURE)
457
+ .value("OBS_UNDISCOVERED", OBS_UNDISCOVERED_FEATURE)
458
+ .value("EXPECTED_OBS_UNDISCOVERED", EXPECTED_OBS_UNDISCOVERED_FEATURE);
459
+
460
+ py::class_<BatchedEnvironment>(m, "BatchedEnvironment")
461
+ .def(py::init<int, int>(), py::arg("n_envs"), py::arg("seed") = 42)
462
+ .def("reset", &BatchedEnvironment::reset)
463
+ .def("step", &BatchedEnvironment::step, py::arg("actions"), py::arg("communication_prob") = -1.0f)
464
+ .def("get_memory_view", &BatchedEnvironment::get_memory_view)
465
+ .def("get_stride", &BatchedEnvironment::get_stride)
466
+ .def("get_flat_map_size", &BatchedEnvironment::get_flat_map_size)
467
+ .def("get_gravity_attractions", &BatchedEnvironment::get_gravity_attractions,
468
+ py::arg("agent_mask") = py::none(),
469
+ py::arg("feature_type"),
470
+ py::arg("pow") = 2,
471
+ py::arg("normalize") = false)
472
+ .def_readonly("num_envs", &BatchedEnvironment::num_envs);
473
+ }