rl-interrogate 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. rl_interrogate-0.1.0/.gitignore +22 -0
  2. rl_interrogate-0.1.0/LICENSE +21 -0
  3. rl_interrogate-0.1.0/PKG-INFO +153 -0
  4. rl_interrogate-0.1.0/README.md +124 -0
  5. rl_interrogate-0.1.0/configs/default_probe_config.yaml +9 -0
  6. rl_interrogate-0.1.0/examples/formation_flight_probe.py +160 -0
  7. rl_interrogate-0.1.0/examples/halfcheetah_ablation.py +221 -0
  8. rl_interrogate-0.1.0/pyproject.toml +42 -0
  9. rl_interrogate-0.1.0/src/rl_interrogate/__init__.py +25 -0
  10. rl_interrogate-0.1.0/src/rl_interrogate/ablation.py +106 -0
  11. rl_interrogate-0.1.0/src/rl_interrogate/env_wrappers.py +147 -0
  12. rl_interrogate-0.1.0/src/rl_interrogate/llm_polarity/__init__.py +35 -0
  13. rl_interrogate-0.1.0/src/rl_interrogate/llm_polarity/control_runner.py +203 -0
  14. rl_interrogate-0.1.0/src/rl_interrogate/llm_polarity/data_models.py +86 -0
  15. rl_interrogate-0.1.0/src/rl_interrogate/llm_polarity/dataset_builder.py +208 -0
  16. rl_interrogate-0.1.0/src/rl_interrogate/llm_polarity/direction_extractor.py +336 -0
  17. rl_interrogate-0.1.0/src/rl_interrogate/llm_polarity/hyperplane_reflector.py +106 -0
  18. rl_interrogate-0.1.0/src/rl_interrogate/llm_polarity/model_loader.py +120 -0
  19. rl_interrogate-0.1.0/src/rl_interrogate/llm_polarity/output_analyzer.py +245 -0
  20. rl_interrogate-0.1.0/src/rl_interrogate/llm_polarity/probe_trainer.py +147 -0
  21. rl_interrogate-0.1.0/src/rl_interrogate/llm_polarity/results_writer.py +295 -0
  22. rl_interrogate-0.1.0/src/rl_interrogate/multi_objective_ablation.py +174 -0
  23. rl_interrogate-0.1.0/src/rl_interrogate/multi_objective_pid.py +254 -0
  24. rl_interrogate-0.1.0/src/rl_interrogate/multi_objective_probing.py +226 -0
  25. rl_interrogate-0.1.0/src/rl_interrogate/multi_objective_training.py +343 -0
  26. rl_interrogate-0.1.0/src/rl_interrogate/multi_reward_halfcheetah.py +152 -0
  27. rl_interrogate-0.1.0/src/rl_interrogate/natural_harmful_utils.py +280 -0
  28. rl_interrogate-0.1.0/src/rl_interrogate/patching.py +75 -0
  29. rl_interrogate-0.1.0/src/rl_interrogate/pca_utils.py +38 -0
  30. rl_interrogate-0.1.0/src/rl_interrogate/polarity.py +180 -0
  31. rl_interrogate-0.1.0/src/rl_interrogate/policies/__init__.py +13 -0
  32. rl_interrogate-0.1.0/src/rl_interrogate/policies/recurrent_policy.py +211 -0
  33. rl_interrogate-0.1.0/src/rl_interrogate/probing.py +289 -0
  34. rl_interrogate-0.1.0/src/rl_interrogate/visualization.py +125 -0
  35. rl_interrogate-0.1.0/tests/__init__.py +1 -0
  36. rl_interrogate-0.1.0/tests/test_hardening.py +249 -0
  37. rl_interrogate-0.1.0/tests/test_llm_polarity_integration.py +340 -0
  38. rl_interrogate-0.1.0/tests/test_llm_polarity_properties.py +317 -0
  39. rl_interrogate-0.1.0/tests/test_llm_polarity_units.py +763 -0
  40. rl_interrogate-0.1.0/tests/test_multi_objective_ablation.py +303 -0
  41. rl_interrogate-0.1.0/tests/test_multi_objective_pid.py +326 -0
  42. rl_interrogate-0.1.0/tests/test_multi_objective_probing.py +379 -0
  43. rl_interrogate-0.1.0/tests/test_multi_objective_probing_properties.py +418 -0
  44. rl_interrogate-0.1.0/tests/test_multi_objective_properties.py +517 -0
  45. rl_interrogate-0.1.0/tests/test_multi_objective_training.py +353 -0
  46. rl_interrogate-0.1.0/tests/test_multi_reward_halfcheetah.py +226 -0
  47. rl_interrogate-0.1.0/tests/test_multi_reward_properties.py +86 -0
  48. rl_interrogate-0.1.0/tests/test_natural_harmful.py +389 -0
  49. rl_interrogate-0.1.0/tests/test_natural_harmful_properties.py +405 -0
  50. rl_interrogate-0.1.0/tests/test_recurrent_policy.py +249 -0
  51. rl_interrogate-0.1.0/tests/test_recurrent_properties.py +301 -0
  52. rl_interrogate-0.1.0/tests/test_roundtrip.py +1201 -0
@@ -0,0 +1,22 @@
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+
5
+ # Environments
6
+ /.venv
7
+ /.env
8
+ miniconda/
9
+ miniconda3/
10
+ Miniconda3-latest-Linux-x86_64.sh
11
+
12
+ # Jupyter
13
+ .ipynb_checkpoints/
14
+
15
+ # OS
16
+ .DS_Store
17
+ Thumbs.db
18
+
19
+ # IDEs
20
+ .vscode/
21
+ .idea/
22
+
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Ansh Arora
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,153 @@
1
+ Metadata-Version: 2.4
2
+ Name: rl-interrogate
3
+ Version: 0.1.0
4
+ Summary: Mechanistic interpretability toolkit for RL policies
5
+ Project-URL: Homepage, https://github.com/aroransh/rl_interrogate
6
+ Project-URL: Repository, https://github.com/aroransh/rl_interrogate
7
+ Project-URL: Documentation, https://github.com/aroransh/rl_interrogate#readme
8
+ Author: Ansh Arora
9
+ License-Expression: MIT
10
+ License-File: LICENSE
11
+ Classifier: License :: OSI Approved :: MIT License
12
+ Classifier: Operating System :: OS Independent
13
+ Classifier: Programming Language :: Python :: 3
14
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
15
+ Requires-Python: >=3.9
16
+ Requires-Dist: gymnasium
17
+ Requires-Dist: matplotlib
18
+ Requires-Dist: numpy
19
+ Requires-Dist: scikit-learn
20
+ Requires-Dist: seaborn
21
+ Requires-Dist: stable-baselines3
22
+ Requires-Dist: torch>=2.0
23
+ Provides-Extra: dev
24
+ Requires-Dist: hypothesis; extra == 'dev'
25
+ Requires-Dist: pytest; extra == 'dev'
26
+ Provides-Extra: mujoco
27
+ Requires-Dist: mujoco; extra == 'mujoco'
28
+ Description-Content-Type: text/markdown
29
+
30
+ # rl_interrogate
31
+
32
+ Mechanistic interpretability toolkit for RL policies. Probe, ablate, and interrogate
33
+ what your policy has learned — not just how well it performs.
34
+
35
+ ## Installation
36
+
37
+ ```bash
38
+ pip install -e .
39
+ ```
40
+
41
+ Dependencies: `torch`, `numpy`, `scikit-learn`, `matplotlib`, `seaborn`, `gymnasium`,
42
+ `stable-baselines3`. MuJoCo environments require the `mujoco` extra:
43
+
44
+ ```bash
45
+ pip install -e ".[mujoco]"
46
+ ```
47
+
48
+ ## Minimal Example
49
+
50
+ Load a checkpoint, run a probe, run ablation:
51
+
52
+ ```python
53
+ import torch
54
+ import numpy as np
55
+ from rl_interrogate import LinearProbe, AblationHook
56
+
57
+ # 1. Load your policy network (any torch.nn.Sequential)
58
+ policy_net = torch.load("my_policy.pt")
59
+ policy_net.eval()
60
+
61
+ # 2. Build a synthetic observation grid
62
+ obs_grid = np.random.randn(500, 28).astype(np.float32)
63
+ labels = obs_grid[:, 10] # probe for lateral position
64
+
65
+ # 3. Linear probe at layer 5
66
+ probe = LinearProbe()
67
+ probe.fit(policy_net, layer_idx=5, obs_dataset=obs_grid, labels=labels)
68
+ print(f"Layer 5 R² = {probe.score():.4f}")
69
+
70
+ # 4. Ablation: zero the probe direction, measure performance change
71
+ hook = AblationHook(policy_net, layer_idx=5, direction=probe._probe.coef_)
72
+ with hook.apply(alpha=0.0):
73
+ # run your environment here — the probe direction is zeroed
74
+ pass
75
+ ```
76
+
77
+ ## Experiments
78
+
79
+ The library was developed for the WakeRider paper (TMLR submission). Key experiments:
80
+
81
+ - **Formation flight probe** (`examples/formation_flight_probe.py`): Reproduces
82
+ Actor L5 R²=0.973 from the seed-42 checkpoint.
83
+ - **HalfCheetah ablation** (`examples/halfcheetah_ablation.py`): Runs ablation on
84
+ a HalfCheetah-v4 policy, showing PC1 ablation degrades performance by ~10%.
85
+
86
+ ## API Reference
87
+
88
+ ### Probing
89
+
90
+ ```python
91
+ from rl_interrogate import LinearProbe, MLPProbe, LassoProbe
92
+
93
+ # Ridge regression probe (recommended)
94
+ probe = LinearProbe()
95
+ probe.fit(model, layer_idx=5, obs_dataset=obs, labels=y)
96
+ r2 = probe.score()
97
+
98
+ # MLP probe (non-linear)
99
+ mlp_probe = MLPProbe()
100
+ mlp_probe.fit(model, layer_idx=5, obs_dataset=obs, labels=y)
101
+
102
+ # Sparse Lasso probe
103
+ lasso = LassoProbe()
104
+ lasso.fit(model, layer_idx=5, obs_dataset=obs, labels=y)
105
+ r2, n_nonzero = lasso.score()
106
+ ```
107
+
108
+ ### Ablation
109
+
110
+ ```python
111
+ from rl_interrogate import AblationHook
112
+
113
+ hook = AblationHook(policy_net, layer_idx=5, direction=probe_direction)
114
+ with hook.apply(alpha=0.0): # alpha=0 zeros the direction
115
+ rewards = run_episodes(model, env, n=100)
116
+ ```
117
+
118
+ ### PCA Utilities
119
+
120
+ ```python
121
+ from rl_interrogate import fit_pca, project_subspace
122
+
123
+ pca = fit_pca(activations, n_components=20)
124
+ acts_k = project_subspace(activations, pca, k=1) # rank-1 projection
125
+ ```
126
+
127
+ ### Visualization
128
+
129
+ ```python
130
+ from rl_interrogate import plot_probe_heatmap, plot_ablation_curve
131
+
132
+ plot_probe_heatmap(activations, labels, title="Layer 5 probe")
133
+ plot_ablation_curve(alphas=[0.0, 0.5, 1.0], means=[1.05, 0.97, 0.90])
134
+ ```
135
+
136
+ ## Running Tests
137
+
138
+ ```bash
139
+ pytest rl_interrogate/tests/ -v
140
+ ```
141
+
142
+ ## Link to Paper
143
+
144
+ This library implements the interrogation protocol described in:
145
+
146
+ > *WakeRider: Emergent V-Formation Flight via Wake Exploitation*
147
+ > Section 3.3: The rl_interrogate Library
148
+
149
+ The protocol consists of four steps:
150
+ 1. **Linear probing** — fit Ridge regression from hidden activations to a field label
151
+ 2. **Polarity inversion** — negate the sensor; verify R² drops (causal, not correlational)
152
+ 3. **Single-direction ablation** — zero the probe direction; measure performance change
153
+ 4. **Subspace variance** — greedy PCA selection to find the minimal sufficient subspace
@@ -0,0 +1,124 @@
1
+ # rl_interrogate
2
+
3
+ Mechanistic interpretability toolkit for RL policies. Probe, ablate, and interrogate
4
+ what your policy has learned — not just how well it performs.
5
+
6
+ ## Installation
7
+
8
+ ```bash
9
+ pip install -e .
10
+ ```
11
+
12
+ Dependencies: `torch`, `numpy`, `scikit-learn`, `matplotlib`, `seaborn`, `gymnasium`,
13
+ `stable-baselines3`. MuJoCo environments require the `mujoco` extra:
14
+
15
+ ```bash
16
+ pip install -e ".[mujoco]"
17
+ ```
18
+
19
+ ## Minimal Example
20
+
21
+ Load a checkpoint, run a probe, run ablation:
22
+
23
+ ```python
24
+ import torch
25
+ import numpy as np
26
+ from rl_interrogate import LinearProbe, AblationHook
27
+
28
+ # 1. Load your policy network (any torch.nn.Sequential)
29
+ policy_net = torch.load("my_policy.pt")
30
+ policy_net.eval()
31
+
32
+ # 2. Build a synthetic observation grid
33
+ obs_grid = np.random.randn(500, 28).astype(np.float32)
34
+ labels = obs_grid[:, 10] # probe for lateral position
35
+
36
+ # 3. Linear probe at layer 5
37
+ probe = LinearProbe()
38
+ probe.fit(policy_net, layer_idx=5, obs_dataset=obs_grid, labels=labels)
39
+ print(f"Layer 5 R² = {probe.score():.4f}")
40
+
41
+ # 4. Ablation: zero the probe direction, measure performance change
42
+ hook = AblationHook(policy_net, layer_idx=5, direction=probe._probe.coef_)
43
+ with hook.apply(alpha=0.0):
44
+ # run your environment here — the probe direction is zeroed
45
+ pass
46
+ ```
47
+
48
+ ## Experiments
49
+
50
+ The library was developed for the WakeRider paper (TMLR submission). Key experiments:
51
+
52
+ - **Formation flight probe** (`examples/formation_flight_probe.py`): Reproduces
53
+ Actor L5 R²=0.973 from the seed-42 checkpoint.
54
+ - **HalfCheetah ablation** (`examples/halfcheetah_ablation.py`): Runs ablation on
55
+ a HalfCheetah-v4 policy, showing PC1 ablation degrades performance by ~10%.
56
+
57
+ ## API Reference
58
+
59
+ ### Probing
60
+
61
+ ```python
62
+ from rl_interrogate import LinearProbe, MLPProbe, LassoProbe
63
+
64
+ # Ridge regression probe (recommended)
65
+ probe = LinearProbe()
66
+ probe.fit(model, layer_idx=5, obs_dataset=obs, labels=y)
67
+ r2 = probe.score()
68
+
69
+ # MLP probe (non-linear)
70
+ mlp_probe = MLPProbe()
71
+ mlp_probe.fit(model, layer_idx=5, obs_dataset=obs, labels=y)
72
+
73
+ # Sparse Lasso probe
74
+ lasso = LassoProbe()
75
+ lasso.fit(model, layer_idx=5, obs_dataset=obs, labels=y)
76
+ r2, n_nonzero = lasso.score()
77
+ ```
78
+
79
+ ### Ablation
80
+
81
+ ```python
82
+ from rl_interrogate import AblationHook
83
+
84
+ hook = AblationHook(policy_net, layer_idx=5, direction=probe_direction)
85
+ with hook.apply(alpha=0.0): # alpha=0 zeros the direction
86
+ rewards = run_episodes(model, env, n=100)
87
+ ```
88
+
89
+ ### PCA Utilities
90
+
91
+ ```python
92
+ from rl_interrogate import fit_pca, project_subspace
93
+
94
+ pca = fit_pca(activations, n_components=20)
95
+ acts_k = project_subspace(activations, pca, k=1) # rank-1 projection
96
+ ```
97
+
98
+ ### Visualization
99
+
100
+ ```python
101
+ from rl_interrogate import plot_probe_heatmap, plot_ablation_curve
102
+
103
+ plot_probe_heatmap(activations, labels, title="Layer 5 probe")
104
+ plot_ablation_curve(alphas=[0.0, 0.5, 1.0], means=[1.05, 0.97, 0.90])
105
+ ```
106
+
107
+ ## Running Tests
108
+
109
+ ```bash
110
+ pytest rl_interrogate/tests/ -v
111
+ ```
112
+
113
+ ## Link to Paper
114
+
115
+ This library implements the interrogation protocol described in:
116
+
117
+ > *WakeRider: Emergent V-Formation Flight via Wake Exploitation*
118
+ > Section 3.3: The rl_interrogate Library
119
+
120
+ The protocol consists of four steps:
121
+ 1. **Linear probing** — fit Ridge regression from hidden activations to a field label
122
+ 2. **Polarity inversion** — negate the sensor; verify R² drops (causal, not correlational)
123
+ 3. **Single-direction ablation** — zero the probe direction; measure performance change
124
+ 4. **Subspace variance** — greedy PCA selection to find the minimal sufficient subspace
@@ -0,0 +1,9 @@
1
+ probe_alpha: 1.0
2
+ test_split: 0.2
3
+ random_state: 42
4
+ n_boot: 1000
5
+ pca_k:
6
+ - 1
7
+ - 5
8
+ - 10
9
+ - 20
@@ -0,0 +1,160 @@
1
+ """Example: Actor L5 probe on formation flight seed-42 checkpoint.
2
+
3
+ Demonstrates reproducing R²=0.9730 from the seed-42 checkpoint using
4
+ LinearProbe from rl_interrogate.
5
+
6
+ Usage::
7
+
8
+ python rl_interrogate/examples/formation_flight_probe.py
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import os
14
+ import sys
15
+
16
+ import numpy as np
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
21
+
22
+ from rl_interrogate import LinearProbe
23
+
24
# ── Constants ─────────────────────────────────────────────────────────────────
# What to probe: checkpoint location, the nn.Sequential index whose activations
# are read, and the reference R² (with allowed deviation) to compare against.
CHECKPOINT_PATH = "runs/formation/formation_seed_42/checkpoints/best_agent.pt"
LAYER_IDX = 5
EXPECTED_R2 = 0.9730
TOLERANCE = 0.01

# Per-follower sizes. OBS_PER_FOLLOWER is only used when the checkpoint has no
# "net.0.weight" to infer the input width from; ACT_PER_FOLLOWER is the width
# of the actor's final Linear layer.
OBS_PER_FOLLOWER = 28
ACT_PER_FOLLOWER = 4

# Physical parameters handed to the wake model when building the probe grid.
DRONE_WINGSPAN_M = 1.5
DRONE_WEIGHT_N = 2.0 * 9.81  # mass × gravitational acceleration (9.81)
AIRSPEED_MS = 15.0
TARGET_ALTITUDE = 10.0

# Prefer the GPU whenever torch can see one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
39
+
40
def _build_actor_net(obs_dim: int, state_dict: dict) -> nn.Sequential:
    """Rebuild the actor MLP and load its weights from a policy state dict.

    Architecture::

        Linear(obs_dim, 256) → ELU → Linear(256, 256) → ELU
        → Linear(256, 128) → ELU → Linear(128, ACT_PER_FOLLOWER)

    so Sequential index 5 is the ELU that follows the 128-unit layer.

    Args:
        obs_dim: Input width of the first Linear layer.
        state_dict: Policy state dict. Only keys starting with ``"net."`` are
            used; the prefix is stripped before loading.

    Returns:
        The network, moved to ``DEVICE`` and switched to eval mode.

    Warns:
        UserWarning: if ``state_dict`` contains no ``"net."`` keys. The network
        is still returned, but with its random initialisation.
    """
    import warnings

    net = nn.Sequential(
        nn.Linear(obs_dim, 256), nn.ELU(),
        nn.Linear(256, 256), nn.ELU(),
        nn.Linear(256, 128), nn.ELU(),
        nn.Linear(128, ACT_PER_FOLLOWER),
    ).to(DEVICE)

    prefix = "net."
    net_sd = {k.removeprefix(prefix): v for k, v in state_dict.items() if k.startswith(prefix)}
    if net_sd:
        # strict=True: any mismatch between this architecture and the
        # checkpoint's layers is reported as an error by load_state_dict.
        net.load_state_dict(net_sd, strict=True)
    else:
        # Nothing to load. The actor keeps its random weights, so any probe R²
        # computed on it says nothing about the trained policy. Keep the
        # best-effort behaviour, but do not let it pass silently.
        warnings.warn(
            "state_dict has no 'net.'-prefixed keys; returning a randomly "
            "initialised actor network (probe results will not reflect the "
            "trained policy).",
            UserWarning,
            stacklevel=2,
        )
    net.eval()
    return net
52
+
53
+
54
def _build_obs_grid(obs_dim: int, raw_ckpt: dict) -> tuple[torch.Tensor, np.ndarray]:
    """Build the 80×120 synthetic observation grid and true upwash labels.

    Followers are placed on a regular grid behind a leader fixed at
    ``(0, 0, TARGET_ALTITUDE)``:

    - x spans [-6, 0] with 80 points.
    - y spans [-3, 3] with 120 points.
    - Every point is at ``TARGET_ALTITUDE``.

    The wake model ``compute_upwash_field_batched`` supplies a label for each
    grid point. Each observation starts as a zero vector of width ``obs_dim``.
    Indices 9-11 are filled with follower-minus-leader position and indices
    14-16 with the follower's own position. The result is then standardised
    when the checkpoint carries scaler statistics.

    Args:
        obs_dim: Width of the observation vector the actor expects.
        raw_ckpt: Full checkpoint dict. If it has ``"state_preprocessor"`` with
            ``"running_mean"`` and ``"running_variance"``, those (sliced to
            ``obs_dim``) are used as ``(x - mean) / (sqrt(var) + 1e-8)``.

    Returns:
        A tuple ``(obs_norm, true_upwash)``:

        - ``obs_norm``: a float32 tensor of shape ``(9600, obs_dim)`` on
          ``DEVICE``.
        - ``true_upwash``: a NumPy array taken from column 0 of the
          wake-model output.

    Raises:
        ImportError: if ``source.drone_formation`` cannot be imported.
    """
    try:
        from source.drone_formation.physics.wake_model_gpu import compute_upwash_field_batched
    except ImportError as err:
        # Chain the original failure so the real missing module stays visible.
        raise ImportError(
            "source.drone_formation is required to build the obs grid. "
            "Run this script from the workspace root."
        ) from err

    x_range = np.linspace(-6.0, 0.0, 80)
    y_range = np.linspace(-3.0, 3.0, 120)
    z_val = TARGET_ALTITUDE
    lead_pos_np = np.array([0.0, 0.0, z_val])

    grid_x, grid_y = np.meshgrid(x_range, y_range, indexing="ij")
    n_points = grid_x.size

    # Follower positions: (n_points, 1, 3); leader position broadcast to (n_points, 3).
    foll_pos_t = torch.tensor(
        np.stack([
            grid_x.ravel().astype(np.float32),
            grid_y.ravel().astype(np.float32),
            np.full(n_points, z_val, dtype=np.float32),
        ], axis=1),
        device=DEVICE,
    ).unsqueeze(1)
    lead_pos_t = (
        torch.tensor(lead_pos_np, dtype=torch.float32, device=DEVICE)
        .unsqueeze(0)
        .expand(n_points, -1)
    )

    with torch.no_grad():
        upwash_t = compute_upwash_field_batched(
            foll_pos_t, lead_pos_t,
            torch.full((n_points,), DRONE_WINGSPAN_M, device=DEVICE),
            torch.full((n_points,), DRONE_WEIGHT_N, device=DEVICE),
            airspeed=AIRSPEED_MS, device=DEVICE,
        )
    true_upwash = upwash_t[:, 0].cpu().numpy()

    # Observation scaler statistics, if the checkpoint stored them.
    scaler_mean = scaler_var = None
    if "state_preprocessor" in raw_ckpt:
        sp = raw_ckpt["state_preprocessor"]
        if "running_mean" in sp and "running_variance" in sp:
            scaler_mean = sp["running_mean"][:obs_dim].to(DEVICE)
            scaler_var = sp["running_variance"][:obs_dim].to(DEVICE)

    def _normalize(x: torch.Tensor) -> torch.Tensor:
        # Standardise with the stored scaler; otherwise only cast to float32.
        if scaler_mean is not None:
            return ((x - scaler_mean) / (torch.sqrt(scaler_var) + 1e-8)).float()
        return x.float()

    obs_batch = torch.zeros(n_points, obs_dim, dtype=torch.float32, device=DEVICE)
    # Follower position relative to the leader.
    obs_batch[:, 9] = foll_pos_t[:, 0, 0] - lead_pos_t[:, 0]
    obs_batch[:, 10] = foll_pos_t[:, 0, 1] - lead_pos_t[:, 1]
    obs_batch[:, 11] = foll_pos_t[:, 0, 2] - lead_pos_t[:, 2]
    # Follower's own position.
    obs_batch[:, 14] = foll_pos_t[:, 0, 0]
    obs_batch[:, 15] = foll_pos_t[:, 0, 1]
    obs_batch[:, 16] = foll_pos_t[:, 0, 2]
    obs_norm = _normalize(obs_batch)

    return obs_norm, true_upwash
114
+
115
+
116
def main():
    """Reproduce the Actor L5 ridge-probe R² for the seed-42 formation checkpoint.

    Steps:

    1. Load ``CHECKPOINT_PATH`` and rebuild the actor network.
    2. Build the synthetic observation grid with wake-model labels.
    3. Fit a ``LinearProbe`` at ``LAYER_IDX`` (``test_size=0.2``, ``seed=42``).
    4. Compare the score with ``EXPECTED_R2``.

    A deviation larger than ``TOLERANCE`` is only printed, not raised.

    Returns:
        The value returned by ``LinearProbe.score()``.

    Raises:
        FileNotFoundError: if ``CHECKPOINT_PATH`` does not exist.
    """
    ckpt_path = CHECKPOINT_PATH

    # ── Load checkpoint ───────────────────────────────────────────────────────
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(
            f"Checkpoint not found: {ckpt_path}\n"
            "Please ensure the formation flight training has completed for seed 42.\n"
            "Run: python scripts/train_sac.py --seed 42 (or the PPO training script)"
        )

    print(f"Loading checkpoint: {ckpt_path}")
    # NOTE: weights_only=False unpickles arbitrary Python objects; only point
    # CHECKPOINT_PATH at files you trust.
    checkpoint = torch.load(ckpt_path, map_location=DEVICE, weights_only=False)
    policy_weights = checkpoint.get("policy", checkpoint)

    # Infer the observation width from the first Linear layer when present.
    first_weight = policy_weights.get("net.0.weight")
    if first_weight is None:
        obs_dim = OBS_PER_FOLLOWER
    else:
        obs_dim = int(first_weight.shape[1])
    print(f"Detected obs dim: {obs_dim}")

    # ── Build actor network ───────────────────────────────────────────────────
    actor_net = _build_actor_net(obs_dim, policy_weights)

    # ── Build obs grid ────────────────────────────────────────────────────────
    print("Building 80×120 observation grid (9600 points)...")
    obs_norm, true_upwash = _build_obs_grid(obs_dim, checkpoint)

    # ── Run LinearProbe at Actor L5 ───────────────────────────────────────────
    print(f"Running LinearProbe at layer {LAYER_IDX}...")
    probe = LinearProbe()
    probe.fit(actor_net, LAYER_IDX, obs_norm, true_upwash, test_size=0.2, seed=42)
    r2 = probe.score()

    print(f"\nActor L5 Ridge R² = {r2:.4f}")
    print(f"Expected R² = {EXPECTED_R2:.4f}")
    gap = abs(r2 - EXPECTED_R2)
    print(f"Difference = {gap:.4f}")

    # `<=` (not `>` with swapped branches) so a NaN gap reports a mismatch.
    within_tolerance = gap <= TOLERANCE
    if within_tolerance:
        print(f"✓ R² matches expected value within tolerance ({TOLERANCE})")
    else:
        print(f"✗ WARNING: R² differs from expected by more than {TOLERANCE}")

    return r2


if __name__ == "__main__":
    main()