rl-interrogate 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rl_interrogate-0.1.0/.gitignore +22 -0
- rl_interrogate-0.1.0/LICENSE +21 -0
- rl_interrogate-0.1.0/PKG-INFO +153 -0
- rl_interrogate-0.1.0/README.md +124 -0
- rl_interrogate-0.1.0/configs/default_probe_config.yaml +9 -0
- rl_interrogate-0.1.0/examples/formation_flight_probe.py +160 -0
- rl_interrogate-0.1.0/examples/halfcheetah_ablation.py +221 -0
- rl_interrogate-0.1.0/pyproject.toml +42 -0
- rl_interrogate-0.1.0/src/rl_interrogate/__init__.py +25 -0
- rl_interrogate-0.1.0/src/rl_interrogate/ablation.py +106 -0
- rl_interrogate-0.1.0/src/rl_interrogate/env_wrappers.py +147 -0
- rl_interrogate-0.1.0/src/rl_interrogate/llm_polarity/__init__.py +35 -0
- rl_interrogate-0.1.0/src/rl_interrogate/llm_polarity/control_runner.py +203 -0
- rl_interrogate-0.1.0/src/rl_interrogate/llm_polarity/data_models.py +86 -0
- rl_interrogate-0.1.0/src/rl_interrogate/llm_polarity/dataset_builder.py +208 -0
- rl_interrogate-0.1.0/src/rl_interrogate/llm_polarity/direction_extractor.py +336 -0
- rl_interrogate-0.1.0/src/rl_interrogate/llm_polarity/hyperplane_reflector.py +106 -0
- rl_interrogate-0.1.0/src/rl_interrogate/llm_polarity/model_loader.py +120 -0
- rl_interrogate-0.1.0/src/rl_interrogate/llm_polarity/output_analyzer.py +245 -0
- rl_interrogate-0.1.0/src/rl_interrogate/llm_polarity/probe_trainer.py +147 -0
- rl_interrogate-0.1.0/src/rl_interrogate/llm_polarity/results_writer.py +295 -0
- rl_interrogate-0.1.0/src/rl_interrogate/multi_objective_ablation.py +174 -0
- rl_interrogate-0.1.0/src/rl_interrogate/multi_objective_pid.py +254 -0
- rl_interrogate-0.1.0/src/rl_interrogate/multi_objective_probing.py +226 -0
- rl_interrogate-0.1.0/src/rl_interrogate/multi_objective_training.py +343 -0
- rl_interrogate-0.1.0/src/rl_interrogate/multi_reward_halfcheetah.py +152 -0
- rl_interrogate-0.1.0/src/rl_interrogate/natural_harmful_utils.py +280 -0
- rl_interrogate-0.1.0/src/rl_interrogate/patching.py +75 -0
- rl_interrogate-0.1.0/src/rl_interrogate/pca_utils.py +38 -0
- rl_interrogate-0.1.0/src/rl_interrogate/polarity.py +180 -0
- rl_interrogate-0.1.0/src/rl_interrogate/policies/__init__.py +13 -0
- rl_interrogate-0.1.0/src/rl_interrogate/policies/recurrent_policy.py +211 -0
- rl_interrogate-0.1.0/src/rl_interrogate/probing.py +289 -0
- rl_interrogate-0.1.0/src/rl_interrogate/visualization.py +125 -0
- rl_interrogate-0.1.0/tests/__init__.py +1 -0
- rl_interrogate-0.1.0/tests/test_hardening.py +249 -0
- rl_interrogate-0.1.0/tests/test_llm_polarity_integration.py +340 -0
- rl_interrogate-0.1.0/tests/test_llm_polarity_properties.py +317 -0
- rl_interrogate-0.1.0/tests/test_llm_polarity_units.py +763 -0
- rl_interrogate-0.1.0/tests/test_multi_objective_ablation.py +303 -0
- rl_interrogate-0.1.0/tests/test_multi_objective_pid.py +326 -0
- rl_interrogate-0.1.0/tests/test_multi_objective_probing.py +379 -0
- rl_interrogate-0.1.0/tests/test_multi_objective_probing_properties.py +418 -0
- rl_interrogate-0.1.0/tests/test_multi_objective_properties.py +517 -0
- rl_interrogate-0.1.0/tests/test_multi_objective_training.py +353 -0
- rl_interrogate-0.1.0/tests/test_multi_reward_halfcheetah.py +226 -0
- rl_interrogate-0.1.0/tests/test_multi_reward_properties.py +86 -0
- rl_interrogate-0.1.0/tests/test_natural_harmful.py +389 -0
- rl_interrogate-0.1.0/tests/test_natural_harmful_properties.py +405 -0
- rl_interrogate-0.1.0/tests/test_recurrent_policy.py +249 -0
- rl_interrogate-0.1.0/tests/test_recurrent_properties.py +301 -0
- rl_interrogate-0.1.0/tests/test_roundtrip.py +1201 -0
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
# Python
|
|
2
|
+
__pycache__/
|
|
3
|
+
*.py[cod]
|
|
4
|
+
|
|
5
|
+
# Environments
|
|
6
|
+
/.venv
|
|
7
|
+
/.env
|
|
8
|
+
miniconda/
|
|
9
|
+
miniconda3/
|
|
10
|
+
Miniconda3-latest-Linux-x86_64.sh
|
|
11
|
+
|
|
12
|
+
# Jupyter
|
|
13
|
+
.ipynb_checkpoints/
|
|
14
|
+
|
|
15
|
+
# OS
|
|
16
|
+
.DS_Store
|
|
17
|
+
Thumbs.db
|
|
18
|
+
|
|
19
|
+
# IDEs
|
|
20
|
+
.vscode/
|
|
21
|
+
.idea/
|
|
22
|
+
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 Ansh Arora
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: rl-interrogate
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Mechanistic interpretability toolkit for RL policies
|
|
5
|
+
Project-URL: Homepage, https://github.com/aroransh/rl_interrogate
|
|
6
|
+
Project-URL: Repository, https://github.com/aroransh/rl_interrogate
|
|
7
|
+
Project-URL: Documentation, https://github.com/aroransh/rl_interrogate#readme
|
|
8
|
+
Author: Ansh Arora
|
|
9
|
+
License-Expression: MIT
|
|
10
|
+
License-File: LICENSE
|
|
11
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
12
|
+
Classifier: Operating System :: OS Independent
|
|
13
|
+
Classifier: Programming Language :: Python :: 3
|
|
14
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
15
|
+
Requires-Python: >=3.9
|
|
16
|
+
Requires-Dist: gymnasium
|
|
17
|
+
Requires-Dist: matplotlib
|
|
18
|
+
Requires-Dist: numpy
|
|
19
|
+
Requires-Dist: scikit-learn
|
|
20
|
+
Requires-Dist: seaborn
|
|
21
|
+
Requires-Dist: stable-baselines3
|
|
22
|
+
Requires-Dist: torch>=2.0
|
|
23
|
+
Provides-Extra: dev
|
|
24
|
+
Requires-Dist: hypothesis; extra == 'dev'
|
|
25
|
+
Requires-Dist: pytest; extra == 'dev'
|
|
26
|
+
Provides-Extra: mujoco
|
|
27
|
+
Requires-Dist: mujoco; extra == 'mujoco'
|
|
28
|
+
Description-Content-Type: text/markdown
|
|
29
|
+
|
|
30
|
+
# rl_interrogate
|
|
31
|
+
|
|
32
|
+
Mechanistic interpretability toolkit for RL policies. Probe, ablate, and interrogate
|
|
33
|
+
what your policy has learned — not just how well it performs.
|
|
34
|
+
|
|
35
|
+
## Installation
|
|
36
|
+
|
|
37
|
+
```bash
|
|
38
|
+
pip install -e .
|
|
39
|
+
```
|
|
40
|
+
|
|
41
|
+
Dependencies: `torch`, `numpy`, `scikit-learn`, `matplotlib`, `seaborn`, `gymnasium`,
|
|
42
|
+
`stable-baselines3`. MuJoCo environments require the `mujoco` extra:
|
|
43
|
+
|
|
44
|
+
```bash
|
|
45
|
+
pip install -e ".[mujoco]"
|
|
46
|
+
```
|
|
47
|
+
|
|
48
|
+
## Minimal Example
|
|
49
|
+
|
|
50
|
+
Load a checkpoint, run a probe, run ablation:
|
|
51
|
+
|
|
52
|
+
```python
|
|
53
|
+
import torch
|
|
54
|
+
import numpy as np
|
|
55
|
+
from rl_interrogate import LinearProbe, AblationHook
|
|
56
|
+
|
|
57
|
+
# 1. Load your policy network (any torch.nn.Sequential)
|
|
58
|
+
policy_net = torch.load("my_policy.pt", weights_only=False)  # full-module pickle: only load trusted files
|
|
59
|
+
policy_net.eval()
|
|
60
|
+
|
|
61
|
+
# 2. Build a synthetic observation grid
|
|
62
|
+
obs_grid = np.random.randn(500, 28).astype(np.float32)
|
|
63
|
+
labels = obs_grid[:, 10] # probe for lateral position
|
|
64
|
+
|
|
65
|
+
# 3. Linear probe at layer 5
|
|
66
|
+
probe = LinearProbe()
|
|
67
|
+
probe.fit(policy_net, layer_idx=5, obs_dataset=obs_grid, labels=labels)
|
|
68
|
+
print(f"Layer 5 R² = {probe.score():.4f}")
|
|
69
|
+
|
|
70
|
+
# 4. Ablation: zero the probe direction, measure performance change
|
|
71
|
+
hook = AblationHook(policy_net, layer_idx=5, direction=probe._probe.coef_)
|
|
72
|
+
with hook.apply(alpha=0.0):
|
|
73
|
+
# run your environment here — the probe direction is zeroed
|
|
74
|
+
pass
|
|
75
|
+
```
|
|
76
|
+
|
|
77
|
+
## Experiments
|
|
78
|
+
|
|
79
|
+
The library was developed for the WakeRider paper (TMLR submission). Key experiments:
|
|
80
|
+
|
|
81
|
+
- **Formation flight probe** (`examples/formation_flight_probe.py`): Reproduces
|
|
82
|
+
Actor L5 R²=0.973 from the seed-42 checkpoint.
|
|
83
|
+
- **HalfCheetah ablation** (`examples/halfcheetah_ablation.py`): Runs ablation on
|
|
84
|
+
a HalfCheetah-v4 policy, showing PC1 ablation degrades performance by ~10%.
|
|
85
|
+
|
|
86
|
+
## API Reference
|
|
87
|
+
|
|
88
|
+
### Probing
|
|
89
|
+
|
|
90
|
+
```python
|
|
91
|
+
from rl_interrogate import LinearProbe, MLPProbe, LassoProbe
|
|
92
|
+
|
|
93
|
+
# Ridge regression probe (recommended)
|
|
94
|
+
probe = LinearProbe()
|
|
95
|
+
probe.fit(model, layer_idx=5, obs_dataset=obs, labels=y)
|
|
96
|
+
r2 = probe.score()
|
|
97
|
+
|
|
98
|
+
# MLP probe (non-linear)
|
|
99
|
+
mlp_probe = MLPProbe()
|
|
100
|
+
mlp_probe.fit(model, layer_idx=5, obs_dataset=obs, labels=y)
|
|
101
|
+
|
|
102
|
+
# Sparse Lasso probe
|
|
103
|
+
lasso = LassoProbe()
|
|
104
|
+
lasso.fit(model, layer_idx=5, obs_dataset=obs, labels=y)
|
|
105
|
+
r2, n_nonzero = lasso.score()
|
|
106
|
+
```
|
|
107
|
+
|
|
108
|
+
### Ablation
|
|
109
|
+
|
|
110
|
+
```python
|
|
111
|
+
from rl_interrogate import AblationHook
|
|
112
|
+
|
|
113
|
+
hook = AblationHook(policy_net, layer_idx=5, direction=probe_direction)
|
|
114
|
+
with hook.apply(alpha=0.0): # alpha=0 zeros the direction
|
|
115
|
+
rewards = run_episodes(model, env, n=100)
|
|
116
|
+
```
|
|
117
|
+
|
|
118
|
+
### PCA Utilities
|
|
119
|
+
|
|
120
|
+
```python
|
|
121
|
+
from rl_interrogate import fit_pca, project_subspace
|
|
122
|
+
|
|
123
|
+
pca = fit_pca(activations, n_components=20)
|
|
124
|
+
acts_k = project_subspace(activations, pca, k=1) # rank-1 projection
|
|
125
|
+
```
|
|
126
|
+
|
|
127
|
+
### Visualization
|
|
128
|
+
|
|
129
|
+
```python
|
|
130
|
+
from rl_interrogate import plot_probe_heatmap, plot_ablation_curve
|
|
131
|
+
|
|
132
|
+
plot_probe_heatmap(activations, labels, title="Layer 5 probe")
|
|
133
|
+
plot_ablation_curve(alphas=[0.0, 0.5, 1.0], means=[1.05, 0.97, 0.90])
|
|
134
|
+
```
|
|
135
|
+
|
|
136
|
+
## Running Tests
|
|
137
|
+
|
|
138
|
+
```bash
|
|
139
|
+
pytest tests/ -v  # from the repository root
|
|
140
|
+
```
|
|
141
|
+
|
|
142
|
+
## Link to Paper
|
|
143
|
+
|
|
144
|
+
This library implements the interrogation protocol described in:
|
|
145
|
+
|
|
146
|
+
> *WakeRider: Emergent V-Formation Flight via Wake Exploitation*
|
|
147
|
+
> Section 3.3: The rl_interrogate Library
|
|
148
|
+
|
|
149
|
+
The protocol consists of four steps:
|
|
150
|
+
1. **Linear probing** — fit Ridge regression from hidden activations to a field label
|
|
151
|
+
2. **Polarity inversion** — negate the sensor; verify R² drops (causal, not correlational)
|
|
152
|
+
3. **Single-direction ablation** — zero the probe direction; measure performance change
|
|
153
|
+
4. **Subspace variance** — greedy PCA selection to find the minimal sufficient subspace
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
# rl_interrogate
|
|
2
|
+
|
|
3
|
+
Mechanistic interpretability toolkit for RL policies. Probe, ablate, and interrogate
|
|
4
|
+
what your policy has learned — not just how well it performs.
|
|
5
|
+
|
|
6
|
+
## Installation
|
|
7
|
+
|
|
8
|
+
```bash
|
|
9
|
+
pip install -e .
|
|
10
|
+
```
|
|
11
|
+
|
|
12
|
+
Dependencies: `torch`, `numpy`, `scikit-learn`, `matplotlib`, `seaborn`, `gymnasium`,
|
|
13
|
+
`stable-baselines3`. MuJoCo environments require the `mujoco` extra:
|
|
14
|
+
|
|
15
|
+
```bash
|
|
16
|
+
pip install -e ".[mujoco]"
|
|
17
|
+
```
|
|
18
|
+
|
|
19
|
+
## Minimal Example
|
|
20
|
+
|
|
21
|
+
Load a checkpoint, run a probe, run ablation:
|
|
22
|
+
|
|
23
|
+
```python
|
|
24
|
+
import torch
|
|
25
|
+
import numpy as np
|
|
26
|
+
from rl_interrogate import LinearProbe, AblationHook
|
|
27
|
+
|
|
28
|
+
# 1. Load your policy network (any torch.nn.Sequential)
|
|
29
|
+
policy_net = torch.load("my_policy.pt", weights_only=False)  # full-module pickle: only load trusted files
|
|
30
|
+
policy_net.eval()
|
|
31
|
+
|
|
32
|
+
# 2. Build a synthetic observation grid
|
|
33
|
+
obs_grid = np.random.randn(500, 28).astype(np.float32)
|
|
34
|
+
labels = obs_grid[:, 10] # probe for lateral position
|
|
35
|
+
|
|
36
|
+
# 3. Linear probe at layer 5
|
|
37
|
+
probe = LinearProbe()
|
|
38
|
+
probe.fit(policy_net, layer_idx=5, obs_dataset=obs_grid, labels=labels)
|
|
39
|
+
print(f"Layer 5 R² = {probe.score():.4f}")
|
|
40
|
+
|
|
41
|
+
# 4. Ablation: zero the probe direction, measure performance change
|
|
42
|
+
hook = AblationHook(policy_net, layer_idx=5, direction=probe._probe.coef_)
|
|
43
|
+
with hook.apply(alpha=0.0):
|
|
44
|
+
# run your environment here — the probe direction is zeroed
|
|
45
|
+
pass
|
|
46
|
+
```
|
|
47
|
+
|
|
48
|
+
## Experiments
|
|
49
|
+
|
|
50
|
+
The library was developed for the WakeRider paper (TMLR submission). Key experiments:
|
|
51
|
+
|
|
52
|
+
- **Formation flight probe** (`examples/formation_flight_probe.py`): Reproduces
|
|
53
|
+
Actor L5 R²=0.973 from the seed-42 checkpoint.
|
|
54
|
+
- **HalfCheetah ablation** (`examples/halfcheetah_ablation.py`): Runs ablation on
|
|
55
|
+
a HalfCheetah-v4 policy, showing PC1 ablation degrades performance by ~10%.
|
|
56
|
+
|
|
57
|
+
## API Reference
|
|
58
|
+
|
|
59
|
+
### Probing
|
|
60
|
+
|
|
61
|
+
```python
|
|
62
|
+
from rl_interrogate import LinearProbe, MLPProbe, LassoProbe
|
|
63
|
+
|
|
64
|
+
# Ridge regression probe (recommended)
|
|
65
|
+
probe = LinearProbe()
|
|
66
|
+
probe.fit(model, layer_idx=5, obs_dataset=obs, labels=y)
|
|
67
|
+
r2 = probe.score()
|
|
68
|
+
|
|
69
|
+
# MLP probe (non-linear)
|
|
70
|
+
mlp_probe = MLPProbe()
|
|
71
|
+
mlp_probe.fit(model, layer_idx=5, obs_dataset=obs, labels=y)
|
|
72
|
+
|
|
73
|
+
# Sparse Lasso probe
|
|
74
|
+
lasso = LassoProbe()
|
|
75
|
+
lasso.fit(model, layer_idx=5, obs_dataset=obs, labels=y)
|
|
76
|
+
r2, n_nonzero = lasso.score()
|
|
77
|
+
```
|
|
78
|
+
|
|
79
|
+
### Ablation
|
|
80
|
+
|
|
81
|
+
```python
|
|
82
|
+
from rl_interrogate import AblationHook
|
|
83
|
+
|
|
84
|
+
hook = AblationHook(policy_net, layer_idx=5, direction=probe_direction)
|
|
85
|
+
with hook.apply(alpha=0.0): # alpha=0 zeros the direction
|
|
86
|
+
rewards = run_episodes(model, env, n=100)
|
|
87
|
+
```
|
|
88
|
+
|
|
89
|
+
### PCA Utilities
|
|
90
|
+
|
|
91
|
+
```python
|
|
92
|
+
from rl_interrogate import fit_pca, project_subspace
|
|
93
|
+
|
|
94
|
+
pca = fit_pca(activations, n_components=20)
|
|
95
|
+
acts_k = project_subspace(activations, pca, k=1) # rank-1 projection
|
|
96
|
+
```
|
|
97
|
+
|
|
98
|
+
### Visualization
|
|
99
|
+
|
|
100
|
+
```python
|
|
101
|
+
from rl_interrogate import plot_probe_heatmap, plot_ablation_curve
|
|
102
|
+
|
|
103
|
+
plot_probe_heatmap(activations, labels, title="Layer 5 probe")
|
|
104
|
+
plot_ablation_curve(alphas=[0.0, 0.5, 1.0], means=[1.05, 0.97, 0.90])
|
|
105
|
+
```
|
|
106
|
+
|
|
107
|
+
## Running Tests
|
|
108
|
+
|
|
109
|
+
```bash
|
|
110
|
+
pytest tests/ -v  # from the repository root
|
|
111
|
+
```
|
|
112
|
+
|
|
113
|
+
## Link to Paper
|
|
114
|
+
|
|
115
|
+
This library implements the interrogation protocol described in:
|
|
116
|
+
|
|
117
|
+
> *WakeRider: Emergent V-Formation Flight via Wake Exploitation*
|
|
118
|
+
> Section 3.3: The rl_interrogate Library
|
|
119
|
+
|
|
120
|
+
The protocol consists of four steps:
|
|
121
|
+
1. **Linear probing** — fit Ridge regression from hidden activations to a field label
|
|
122
|
+
2. **Polarity inversion** — negate the sensor; verify R² drops (causal, not correlational)
|
|
123
|
+
3. **Single-direction ablation** — zero the probe direction; measure performance change
|
|
124
|
+
4. **Subspace variance** — greedy PCA selection to find the minimal sufficient subspace
|
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
"""Example: Actor L5 probe on formation flight seed-42 checkpoint.
|
|
2
|
+
|
|
3
|
+
Demonstrates reproducing R²=0.9730 from the seed-42 checkpoint using
|
|
4
|
+
LinearProbe from rl_interrogate.
|
|
5
|
+
|
|
6
|
+
Usage::
|
|
7
|
+
|
|
8
|
+
python rl_interrogate/examples/formation_flight_probe.py
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import os
|
|
14
|
+
import sys
|
|
15
|
+
|
|
16
|
+
import numpy as np
|
|
17
|
+
import torch
|
|
18
|
+
import torch.nn as nn
|
|
19
|
+
|
|
20
|
+
# Prepend the directory two levels above this file to sys.path so that
# workspace packages can be imported without installing them. This is
# presumably the workspace root that provides `source.drone_formation`
# (imported lazily in _build_obs_grid) — TODO confirm against the intended
# checkout layout.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
|
21
|
+
|
|
22
|
+
from rl_interrogate import LinearProbe
|
|
23
|
+
|
|
24
|
+
# ── Constants ─────────────────────────────────────────────────────────────────
# Checkpoint, probed layer, and the reference result this script reproduces.
CHECKPOINT_PATH = "runs/formation/formation_seed_42/checkpoints/best_agent.pt"
LAYER_IDX = 5
EXPECTED_R2 = 0.9730
TOLERANCE = 0.01  # largest |R² - EXPECTED_R2| still reported as a match

# Per-follower sizes and the physical parameters fed to the wake model.
OBS_PER_FOLLOWER = 28  # fallback obs width when the checkpoint lacks "net.0.weight"
ACT_PER_FOLLOWER = 4   # actor output width
DRONE_WINGSPAN_M = 1.5
DRONE_WEIGHT_N = 2.0 * 9.81  # weight in newtons (presumably 2 kg × g)
AIRSPEED_MS = 15.0
TARGET_ALTITUDE = 10.0

# Run on the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _build_actor_net(obs_dim: int, state_dict: dict) -> nn.Sequential:
    """Rebuild the actor MLP and load its weights from a policy state dict.

    The architecture is hard-coded as
    ``obs_dim -> 256 -> 256 -> 128 -> ACT_PER_FOLLOWER`` with ELU activations.
    In this ``nn.Sequential``, index 5 is the ELU after the 256->128 layer.
    That is presumably what ``LAYER_IDX = 5`` addresses, if ``LinearProbe``
    indexes Sequential children (TODO confirm).

    Args:
        obs_dim: Width of the observation vector fed to the first Linear layer.
        state_dict: Policy state dict. Only keys prefixed with ``"net."`` are
            used; the prefix is stripped before loading.

    Returns:
        The network on ``DEVICE``, in eval mode.

    Warns:
        UserWarning: If ``state_dict`` contains no ``"net."`` keys. The
            returned network then keeps its random initialisation, and any
            probe result computed from it is meaningless.
    """
    import warnings

    net = nn.Sequential(
        nn.Linear(obs_dim, 256), nn.ELU(),
        nn.Linear(256, 256), nn.ELU(),
        nn.Linear(256, 128), nn.ELU(),
        nn.Linear(128, ACT_PER_FOLLOWER),
    ).to(DEVICE)

    prefix = "net."
    net_sd = {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }
    if net_sd:
        # strict=True: any missing/unexpected key or shape mismatch means the
        # hard-coded architecture disagrees with the checkpoint -> fail loudly.
        net.load_state_dict(net_sd, strict=True)
    else:
        # Previously this fell through silently and probed random weights.
        warnings.warn(
            "No 'net.'-prefixed weights found in the state dict; the actor "
            "network is randomly initialised and probe results will be "
            "meaningless.",
            stacklevel=2,
        )
    net.eval()
    return net
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _build_obs_grid(obs_dim: int, raw_ckpt: dict):
    """Build the 80×120 synthetic observation grid and true upwash labels.

    The leader is fixed at ``(0, 0, TARGET_ALTITUDE)``. Follower positions
    lie on a regular grid with x in [-6, 0] (80 points) and y in [-3, 3]
    (120 points), all at ``TARGET_ALTITUDE``, for 9600 points in total.

    Labels come from the project's batched wake model. The matching
    observation vector is zero (before normalisation) except for two groups
    of slots:

    * 9-11: follower-minus-leader position.
    * 14-16: absolute follower position.

    If the checkpoint carries running scaler stats, the observations are
    normalised with them.

    Args:
        obs_dim: Observation width expected by the actor. Must be at least 17
            so the position slots fit.
        raw_ckpt: Full checkpoint dict. Its optional ``"state_preprocessor"``
            entry may supply ``"running_mean"`` and ``"running_variance"``.

    Returns:
        A tuple ``(obs_norm, true_upwash)``:

        * ``obs_norm``: float32 tensor of shape ``(9600, obs_dim)`` on
          ``DEVICE``.
        * ``true_upwash``: NumPy array taken from column 0 of the wake-model
          output (presumably one value per grid point — TODO confirm).

    Raises:
        ValueError: If ``obs_dim`` is too small for the position slots.
        ImportError: If ``source.drone_formation`` cannot be imported.
    """
    # Highest observation slot written below is index 16; check this before
    # doing any expensive work instead of failing later with an IndexError.
    min_obs_dim = 17
    if obs_dim < min_obs_dim:
        raise ValueError(
            f"obs_dim={obs_dim} is too small: the observation grid writes "
            f"position features up to index {min_obs_dim - 1}"
        )

    try:
        from source.drone_formation.physics.wake_model_gpu import compute_upwash_field_batched
    except ImportError as err:
        raise ImportError(
            "source.drone_formation is required to build the obs grid. "
            "Run this script from the workspace root."
        ) from err

    x_range = np.linspace(-6.0, 0.0, 80)
    y_range = np.linspace(-3.0, 3.0, 120)
    z_val = TARGET_ALTITUDE
    lead_pos_np = np.array([0.0, 0.0, z_val])

    grid_x, grid_y = np.meshgrid(x_range, y_range, indexing="ij")
    n_points = grid_x.size

    # Follower positions as (N, 1, 3). The singleton axis is what the wake
    # model is called with here; presumably it is a per-follower axis (TODO
    # confirm).
    foll_pos_t = torch.tensor(
        np.stack([
            grid_x.ravel().astype(np.float32),
            grid_y.ravel().astype(np.float32),
            np.full(n_points, z_val, dtype=np.float32),
        ], axis=1),
        device=DEVICE,
    ).unsqueeze(1)
    # Leader position broadcast to (N, 3) as a view (expand does not copy).
    lead_pos_t = (
        torch.tensor(lead_pos_np, dtype=torch.float32, device=DEVICE)
        .unsqueeze(0)
        .expand(n_points, -1)
    )

    with torch.no_grad():
        upwash_t = compute_upwash_field_batched(
            foll_pos_t, lead_pos_t,
            torch.full((n_points,), DRONE_WINGSPAN_M, device=DEVICE),
            torch.full((n_points,), DRONE_WEIGHT_N, device=DEVICE),
            airspeed=AIRSPEED_MS, device=DEVICE,
        )
    true_upwash = upwash_t[:, 0].cpu().numpy()

    # Load scaler stats if the checkpoint saved them, truncated to obs_dim.
    scaler_mean = scaler_var = None
    if "state_preprocessor" in raw_ckpt:
        sp = raw_ckpt["state_preprocessor"]
        if "running_mean" in sp and "running_variance" in sp:
            scaler_mean = sp["running_mean"][:obs_dim].to(DEVICE)
            scaler_var = sp["running_variance"][:obs_dim].to(DEVICE)

    def _normalize(x: torch.Tensor) -> torch.Tensor:
        # Standardise with the running stats (1e-8 guards against zero
        # variance); pass through unchanged when none are available.
        if scaler_mean is not None:
            return ((x - scaler_mean) / (torch.sqrt(scaler_var) + 1e-8)).float()
        return x.float()

    obs_batch = torch.zeros(n_points, obs_dim, dtype=torch.float32, device=DEVICE)
    # Slots 9-11: follower position relative to the leader (x, y, z).
    obs_batch[:, 9] = foll_pos_t[:, 0, 0] - lead_pos_t[:, 0]
    obs_batch[:, 10] = foll_pos_t[:, 0, 1] - lead_pos_t[:, 1]
    obs_batch[:, 11] = foll_pos_t[:, 0, 2] - lead_pos_t[:, 2]
    # Slots 14-16: absolute follower position (x, y, z).
    obs_batch[:, 14] = foll_pos_t[:, 0, 0]
    obs_batch[:, 15] = foll_pos_t[:, 0, 1]
    obs_batch[:, 16] = foll_pos_t[:, 0, 2]
    obs_norm = _normalize(obs_batch)

    return obs_norm, true_upwash
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def main():
    """Reproduce the Actor L5 Ridge-probe R² on the seed-42 checkpoint.

    Steps:

    1. Load ``CHECKPOINT_PATH`` and rebuild the actor network.
    2. Build the synthetic observation grid with its true upwash labels.
    3. Fit ``LinearProbe`` at layer ``LAYER_IDX``, using a 20% test split
       and seed 42.
    4. Print whether the score is within ``TOLERANCE`` of ``EXPECTED_R2``.

    Returns:
        The R² reported by ``LinearProbe.score()``.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
    """
    # ── Load checkpoint ───────────────────────────────────────────────────────
    if not os.path.exists(CHECKPOINT_PATH):
        raise FileNotFoundError(
            f"Checkpoint not found: {CHECKPOINT_PATH}\n"
            "Please ensure the formation flight training has completed for seed 42.\n"
            "Run: python scripts/train_sac.py --seed 42 (or the PPO training script)"
        )

    print(f"Loading checkpoint: {CHECKPOINT_PATH}")
    # weights_only=False lets torch.load unpickle arbitrary Python objects,
    # which can execute code; only run this on checkpoints you trust.
    raw_ckpt = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=False)
    # Actor weights may be nested under "policy"; otherwise treat the whole
    # checkpoint dict as the state dict.
    policy_sd = raw_ckpt.get("policy", raw_ckpt)

    # Detect obs dim from the first Linear weight, whose shape is
    # (out_features, in_features). Fall back to the per-follower default
    # when that key is absent.
    w0 = policy_sd.get("net.0.weight")
    obs_dim = int(w0.shape[1]) if w0 is not None else OBS_PER_FOLLOWER
    print(f"Detected obs dim: {obs_dim}")

    # ── Build actor network ───────────────────────────────────────────────────
    actor_net = _build_actor_net(obs_dim, policy_sd)

    # ── Build obs grid ────────────────────────────────────────────────────────
    print("Building 80×120 observation grid (9600 points)...")
    obs_norm, true_upwash = _build_obs_grid(obs_dim, raw_ckpt)

    # ── Run LinearProbe at the configured actor layer ─────────────────────────
    print(f"Running LinearProbe at layer {LAYER_IDX}...")
    probe = LinearProbe()
    probe.fit(actor_net, LAYER_IDX, obs_norm, true_upwash, test_size=0.2, seed=42)
    r2 = probe.score()

    diff = abs(r2 - EXPECTED_R2)
    # Label derived from LAYER_IDX so it stays correct if the layer changes.
    print(f"\nActor L{LAYER_IDX} Ridge R² = {r2:.4f}")
    print(f"Expected R² = {EXPECTED_R2:.4f}")
    print(f"Difference = {diff:.4f}")

    if diff <= TOLERANCE:
        print(f"✓ R² matches expected value within tolerance ({TOLERANCE})")
    else:
        print(f"✗ WARNING: R² differs from expected by more than {TOLERANCE}")

    return r2


if __name__ == "__main__":
    main()
|