alignment-risk 0.1.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- alignment_risk-0.1.3/LICENSE +21 -0
- alignment_risk-0.1.3/MANIFEST.in +6 -0
- alignment_risk-0.1.3/PKG-INFO +194 -0
- alignment_risk-0.1.3/README.md +160 -0
- alignment_risk-0.1.3/pyproject.toml +74 -0
- alignment_risk-0.1.3/setup.cfg +4 -0
- alignment_risk-0.1.3/src/alignment_risk/__about__.py +5 -0
- alignment_risk-0.1.3/src/alignment_risk/__init__.py +50 -0
- alignment_risk-0.1.3/src/alignment_risk/__main__.py +4 -0
- alignment_risk-0.1.3/src/alignment_risk/cli.py +43 -0
- alignment_risk-0.1.3/src/alignment_risk/curvature.py +115 -0
- alignment_risk-0.1.3/src/alignment_risk/demo.py +115 -0
- alignment_risk-0.1.3/src/alignment_risk/fisher.py +545 -0
- alignment_risk-0.1.3/src/alignment_risk/forecast.py +69 -0
- alignment_risk-0.1.3/src/alignment_risk/mitigation.py +222 -0
- alignment_risk-0.1.3/src/alignment_risk/orthogonality.py +57 -0
- alignment_risk-0.1.3/src/alignment_risk/pipeline.py +372 -0
- alignment_risk-0.1.3/src/alignment_risk/py.typed +0 -0
- alignment_risk-0.1.3/src/alignment_risk/types.py +63 -0
- alignment_risk-0.1.3/src/alignment_risk/utils.py +160 -0
- alignment_risk-0.1.3/src/alignment_risk/visualization.py +89 -0
- alignment_risk-0.1.3/src/alignment_risk.egg-info/PKG-INFO +194 -0
- alignment_risk-0.1.3/src/alignment_risk.egg-info/SOURCES.txt +36 -0
- alignment_risk-0.1.3/src/alignment_risk.egg-info/dependency_links.txt +1 -0
- alignment_risk-0.1.3/src/alignment_risk.egg-info/entry_points.txt +2 -0
- alignment_risk-0.1.3/src/alignment_risk.egg-info/requires.txt +14 -0
- alignment_risk-0.1.3/src/alignment_risk.egg-info/top_level.txt +1 -0
- alignment_risk-0.1.3/tests/test_cli.py +24 -0
- alignment_risk-0.1.3/tests/test_curvature.py +71 -0
- alignment_risk-0.1.3/tests/test_device.py +17 -0
- alignment_risk-0.1.3/tests/test_fisher_options.py +230 -0
- alignment_risk-0.1.3/tests/test_fisher_precision.py +44 -0
- alignment_risk-0.1.3/tests/test_forecast.py +45 -0
- alignment_risk-0.1.3/tests/test_mitigation.py +112 -0
- alignment_risk-0.1.3/tests/test_modes.py +56 -0
- alignment_risk-0.1.3/tests/test_orthogonality.py +38 -0
- alignment_risk-0.1.3/tests/test_pipeline_run.py +434 -0
- alignment_risk-0.1.3/tests/test_utils_batch.py +60 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Mohammed Talat
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: alignment-risk
|
|
3
|
+
Version: 0.1.3
|
|
4
|
+
Summary: Template toolkit for AIC-based alignment collapse diagnostics
|
|
5
|
+
Author: Mohammed Talat
|
|
6
|
+
License-Expression: MIT
|
|
7
|
+
Keywords: alignment,llm,safety,fine-tuning,risk,fisher
|
|
8
|
+
Classifier: Development Status :: 3 - Alpha
|
|
9
|
+
Classifier: Intended Audience :: Developers
|
|
10
|
+
Classifier: Intended Audience :: Science/Research
|
|
11
|
+
Classifier: Operating System :: OS Independent
|
|
12
|
+
Classifier: Programming Language :: Python :: 3
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
16
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
17
|
+
Requires-Python: <3.13,>=3.10
|
|
18
|
+
Description-Content-Type: text/markdown
|
|
19
|
+
License-File: LICENSE
|
|
20
|
+
Requires-Dist: torch<2.7,>=2.3
|
|
21
|
+
Requires-Dist: numpy<3.0,>=1.24
|
|
22
|
+
Requires-Dist: matplotlib<3.11,>=3.8
|
|
23
|
+
Provides-Extra: dev
|
|
24
|
+
Requires-Dist: build<2.0,>=1.2; extra == "dev"
|
|
25
|
+
Requires-Dist: twine<7.0,>=6.0; extra == "dev"
|
|
26
|
+
Requires-Dist: pytest<9.0,>=8.0; extra == "dev"
|
|
27
|
+
Requires-Dist: pytest-xdist<4.0,>=3.6; extra == "dev"
|
|
28
|
+
Requires-Dist: ruff<1.0,>=0.8; extra == "dev"
|
|
29
|
+
Requires-Dist: mypy<2.0,>=1.11; extra == "dev"
|
|
30
|
+
Requires-Dist: mkdocs<2.0,>=1.6; extra == "dev"
|
|
31
|
+
Requires-Dist: mkdocs-material<10.0,>=9.6; extra == "dev"
|
|
32
|
+
Requires-Dist: pymdown-extensions<11.0,>=10.0; extra == "dev"
|
|
33
|
+
Dynamic: license-file
|
|
34
|
+
|
|
35
|
+
# alignment-risk
|
|
36
|
+
|
|
37
|
+
`alignment-risk` is a Python package for pre-flight alignment risk diagnostics for fine-tuning runs.
|
|
38
|
+
It estimates whether updates are likely to drift into safety-sensitive directions before you commit to a run.
|
|
39
|
+
|
|
40
|
+
Full documentation site: [sirhan1.github.io/modelFineTuneRiskAssessment](https://sirhan1.github.io/modelFineTuneRiskAssessment/)
|
|
41
|
+
|
|
42
|
+
## Quick Start (1 minute)
|
|
43
|
+
|
|
44
|
+
Install:
|
|
45
|
+
|
|
46
|
+
```bash
|
|
47
|
+
pip install alignment-risk
|
|
48
|
+
```
|
|
49
|
+
|
|
50
|
+
Run the built-in demo:
|
|
51
|
+
|
|
52
|
+
```bash
|
|
53
|
+
alignment-risk demo --output-dir artifacts
|
|
54
|
+
```
|
|
55
|
+
|
|
56
|
+
Outputs:
|
|
57
|
+
- `artifacts/sensitivity_map.png`
|
|
58
|
+
- `artifacts/safety_decay_forecast.png`
|
|
59
|
+
|
|
60
|
+
## Installation
|
|
61
|
+
|
|
62
|
+
PyPI:
|
|
63
|
+
|
|
64
|
+
```bash
|
|
65
|
+
pip install alignment-risk
|
|
66
|
+
```
|
|
67
|
+
|
|
68
|
+
Local development (Apple Silicon convenience):
|
|
69
|
+
|
|
70
|
+
```bash
|
|
71
|
+
make setup
|
|
72
|
+
source .venv/bin/activate
|
|
73
|
+
```
|
|
74
|
+
|
|
75
|
+
Local development (manual, cross-platform):
|
|
76
|
+
|
|
77
|
+
```bash
|
|
78
|
+
python -m venv .venv
|
|
79
|
+
source .venv/bin/activate
|
|
80
|
+
pip install -e ".[dev]"
|
|
81
|
+
```
|
|
82
|
+
|
|
83
|
+
## Basic Usage (Python API)
|
|
84
|
+
|
|
85
|
+
```python
|
|
86
|
+
from alignment_risk import AlignmentRiskPipeline, PipelineConfig
|
|
87
|
+
|
|
88
|
+
config = PipelineConfig(mode="lora") # "full" or "lora"
|
|
89
|
+
pipeline = AlignmentRiskPipeline(config)
|
|
90
|
+
|
|
91
|
+
report = pipeline.run(
|
|
92
|
+
model=model,
|
|
93
|
+
safety_dataloader=safety_loader,
|
|
94
|
+
safety_loss_fn=safety_loss_fn,
|
|
95
|
+
fine_tune_dataloader=ft_loader,
|
|
96
|
+
fine_tune_loss_fn=ft_loss_fn,
|
|
97
|
+
)
|
|
98
|
+
|
|
99
|
+
print(report.warning)
|
|
100
|
+
print(report.forecast.collapse_step)
|
|
101
|
+
```
|
|
102
|
+
|
|
103
|
+
## CLI
|
|
104
|
+
|
|
105
|
+
```bash
|
|
106
|
+
alignment-risk --help
|
|
107
|
+
alignment-risk demo --output-dir artifacts
|
|
108
|
+
alignment-risk demo --mode lora --output-dir artifacts
|
|
109
|
+
```
|
|
110
|
+
|
|
111
|
+
## What It Computes
|
|
112
|
+
|
|
113
|
+
1. Low-rank safety sensitivity subspace from empirical Fisher geometry.
|
|
114
|
+
2. Initial overlap risk (projection of first update into sensitive subspace).
|
|
115
|
+
3. Curvature coupling risk (second-order drift signal).
|
|
116
|
+
4. Quartic-style stability forecast and collapse-step estimate.
|
|
117
|
+
|
|
118
|
+
## Modes
|
|
119
|
+
|
|
120
|
+
- `full`: analyze all selected trainable parameters.
|
|
121
|
+
- `lora`: analyze only trainable LoRA adapter parameters (`lora_`, `lora_A`, `lora_B` by default).
|
|
122
|
+
|
|
123
|
+
## LoRA Mitigation (AlignGuard-style)
|
|
124
|
+
|
|
125
|
+
After `mode="lora"` risk analysis, attach a regularizer to penalize drift in sensitive directions:
|
|
126
|
+
|
|
127
|
+
```python
|
|
128
|
+
from alignment_risk import AlignmentRiskPipeline, PipelineConfig, AlignGuardConfig
|
|
129
|
+
|
|
130
|
+
config = PipelineConfig(mode="lora")
|
|
131
|
+
pipeline = AlignmentRiskPipeline(config)
|
|
132
|
+
|
|
133
|
+
report = pipeline.run(
|
|
134
|
+
model=model,
|
|
135
|
+
safety_dataloader=safety_loader,
|
|
136
|
+
safety_loss_fn=safety_loss_fn,
|
|
137
|
+
fine_tune_dataloader=ft_loader,
|
|
138
|
+
fine_tune_loss_fn=ft_loss_fn,
|
|
139
|
+
)
|
|
140
|
+
|
|
141
|
+
mitigator = pipeline.build_lora_mitigator(
|
|
142
|
+
model,
|
|
143
|
+
report.subspace,
|
|
144
|
+
config=AlignGuardConfig(lambda_a=0.25, lambda_t=0.5, lambda_nc=0.1, alpha=0.5),
|
|
145
|
+
)
|
|
146
|
+
|
|
147
|
+
task_loss = ft_loss_fn(model, batch)
|
|
148
|
+
breakdown = mitigator.regularized_loss(task_loss)
|
|
149
|
+
breakdown.total_loss.backward()
|
|
150
|
+
```
|
|
151
|
+
|
|
152
|
+
Use `mitigator.reset_reference()` to re-anchor regularization at the current adapter state.
|
|
153
|
+
|
|
154
|
+
## Performance / Accuracy Controls
|
|
155
|
+
|
|
156
|
+
`FisherConfig` supports speed/accuracy tradeoffs:
|
|
157
|
+
|
|
158
|
+
- `gradient_collection`: `"loop"` (default), `"auto"`, `"vmap"`.
|
|
159
|
+
- `subspace_method`: `"svd"`, `"randomized_svd"`, `"diag_topk"`.
|
|
160
|
+
- `vmap_chunk_size`: optional chunking for lower memory.
|
|
161
|
+
- `target_explained_variance`: auto-rank selection (default `0.9`).
|
|
162
|
+
|
|
163
|
+
Example:
|
|
164
|
+
|
|
165
|
+
```python
|
|
166
|
+
from alignment_risk import AlignmentRiskPipeline, PipelineConfig
|
|
167
|
+
|
|
168
|
+
config = PipelineConfig()
|
|
169
|
+
config.fisher.gradient_collection = "vmap"
|
|
170
|
+
config.fisher.subspace_method = "randomized_svd"
|
|
171
|
+
config.fisher.vmap_chunk_size = 16
|
|
172
|
+
```
|
|
173
|
+
|
|
174
|
+
## Theory and Math References
|
|
175
|
+
|
|
176
|
+
- **[AIC-2026]** Springer, Max, et al. (2026). *The Geometry of Alignment Collapse: When Fine-Tuning Breaks Safety*. arXiv:2602.15799v1. PDF: [https://arxiv.org/pdf/2602.15799](https://arxiv.org/pdf/2602.15799)
|
|
177
|
+
- **[ALIGNGUARD-2025]** Das, Amitava, et al. (2025). *AlignGuard-LoRA: Alignment-Preserving Fine-Tuning via Fisher-Guided Decomposition and Riemannian-Geodesic Collision Regularization*. arXiv:2508.02079v1. PDF: [https://arxiv.org/pdf/2508.02079](https://arxiv.org/pdf/2508.02079)
|
|
178
|
+
|
|
179
|
+
Detailed internal mappings and equations:
|
|
180
|
+
- `docs/SOURCES.md`
|
|
181
|
+
- `docs/MATH.md`
|
|
182
|
+
|
|
183
|
+
## Development
|
|
184
|
+
|
|
185
|
+
```bash
|
|
186
|
+
make install
|
|
187
|
+
make test
|
|
188
|
+
make lint
|
|
189
|
+
make typecheck
|
|
190
|
+
make build
|
|
191
|
+
make check-dist
|
|
192
|
+
```
|
|
193
|
+
|
|
194
|
+
See [CONTRIBUTING.md](CONTRIBUTING.md) for contribution workflow details.
|
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
# alignment-risk
|
|
2
|
+
|
|
3
|
+
`alignment-risk` is a Python package for pre-flight alignment risk diagnostics for fine-tuning runs.
|
|
4
|
+
It estimates whether updates are likely to drift into safety-sensitive directions before you commit to a run.
|
|
5
|
+
|
|
6
|
+
Full documentation site: [sirhan1.github.io/modelFineTuneRiskAssessment](https://sirhan1.github.io/modelFineTuneRiskAssessment/)
|
|
7
|
+
|
|
8
|
+
## Quick Start (1 minute)
|
|
9
|
+
|
|
10
|
+
Install:
|
|
11
|
+
|
|
12
|
+
```bash
|
|
13
|
+
pip install alignment-risk
|
|
14
|
+
```
|
|
15
|
+
|
|
16
|
+
Run the built-in demo:
|
|
17
|
+
|
|
18
|
+
```bash
|
|
19
|
+
alignment-risk demo --output-dir artifacts
|
|
20
|
+
```
|
|
21
|
+
|
|
22
|
+
Outputs:
|
|
23
|
+
- `artifacts/sensitivity_map.png`
|
|
24
|
+
- `artifacts/safety_decay_forecast.png`
|
|
25
|
+
|
|
26
|
+
## Installation
|
|
27
|
+
|
|
28
|
+
PyPI:
|
|
29
|
+
|
|
30
|
+
```bash
|
|
31
|
+
pip install alignment-risk
|
|
32
|
+
```
|
|
33
|
+
|
|
34
|
+
Local development (Apple Silicon convenience):
|
|
35
|
+
|
|
36
|
+
```bash
|
|
37
|
+
make setup
|
|
38
|
+
source .venv/bin/activate
|
|
39
|
+
```
|
|
40
|
+
|
|
41
|
+
Local development (manual, cross-platform):
|
|
42
|
+
|
|
43
|
+
```bash
|
|
44
|
+
python -m venv .venv
|
|
45
|
+
source .venv/bin/activate
|
|
46
|
+
pip install -e ".[dev]"
|
|
47
|
+
```
|
|
48
|
+
|
|
49
|
+
## Basic Usage (Python API)
|
|
50
|
+
|
|
51
|
+
```python
|
|
52
|
+
from alignment_risk import AlignmentRiskPipeline, PipelineConfig
|
|
53
|
+
|
|
54
|
+
config = PipelineConfig(mode="lora") # "full" or "lora"
|
|
55
|
+
pipeline = AlignmentRiskPipeline(config)
|
|
56
|
+
|
|
57
|
+
report = pipeline.run(
|
|
58
|
+
model=model,
|
|
59
|
+
safety_dataloader=safety_loader,
|
|
60
|
+
safety_loss_fn=safety_loss_fn,
|
|
61
|
+
fine_tune_dataloader=ft_loader,
|
|
62
|
+
fine_tune_loss_fn=ft_loss_fn,
|
|
63
|
+
)
|
|
64
|
+
|
|
65
|
+
print(report.warning)
|
|
66
|
+
print(report.forecast.collapse_step)
|
|
67
|
+
```
|
|
68
|
+
|
|
69
|
+
## CLI
|
|
70
|
+
|
|
71
|
+
```bash
|
|
72
|
+
alignment-risk --help
|
|
73
|
+
alignment-risk demo --output-dir artifacts
|
|
74
|
+
alignment-risk demo --mode lora --output-dir artifacts
|
|
75
|
+
```
|
|
76
|
+
|
|
77
|
+
## What It Computes
|
|
78
|
+
|
|
79
|
+
1. Low-rank safety sensitivity subspace from empirical Fisher geometry.
|
|
80
|
+
2. Initial overlap risk (projection of first update into sensitive subspace).
|
|
81
|
+
3. Curvature coupling risk (second-order drift signal).
|
|
82
|
+
4. Quartic-style stability forecast and collapse-step estimate.
|
|
83
|
+
|
|
84
|
+
## Modes
|
|
85
|
+
|
|
86
|
+
- `full`: analyze all selected trainable parameters.
|
|
87
|
+
- `lora`: analyze only trainable LoRA adapter parameters (`lora_`, `lora_A`, `lora_B` by default).
|
|
88
|
+
|
|
89
|
+
## LoRA Mitigation (AlignGuard-style)
|
|
90
|
+
|
|
91
|
+
After `mode="lora"` risk analysis, attach a regularizer to penalize drift in sensitive directions:
|
|
92
|
+
|
|
93
|
+
```python
|
|
94
|
+
from alignment_risk import AlignmentRiskPipeline, PipelineConfig, AlignGuardConfig
|
|
95
|
+
|
|
96
|
+
config = PipelineConfig(mode="lora")
|
|
97
|
+
pipeline = AlignmentRiskPipeline(config)
|
|
98
|
+
|
|
99
|
+
report = pipeline.run(
|
|
100
|
+
model=model,
|
|
101
|
+
safety_dataloader=safety_loader,
|
|
102
|
+
safety_loss_fn=safety_loss_fn,
|
|
103
|
+
fine_tune_dataloader=ft_loader,
|
|
104
|
+
fine_tune_loss_fn=ft_loss_fn,
|
|
105
|
+
)
|
|
106
|
+
|
|
107
|
+
mitigator = pipeline.build_lora_mitigator(
|
|
108
|
+
model,
|
|
109
|
+
report.subspace,
|
|
110
|
+
config=AlignGuardConfig(lambda_a=0.25, lambda_t=0.5, lambda_nc=0.1, alpha=0.5),
|
|
111
|
+
)
|
|
112
|
+
|
|
113
|
+
task_loss = ft_loss_fn(model, batch)
|
|
114
|
+
breakdown = mitigator.regularized_loss(task_loss)
|
|
115
|
+
breakdown.total_loss.backward()
|
|
116
|
+
```
|
|
117
|
+
|
|
118
|
+
Use `mitigator.reset_reference()` to re-anchor regularization at the current adapter state.
|
|
119
|
+
|
|
120
|
+
## Performance / Accuracy Controls
|
|
121
|
+
|
|
122
|
+
`FisherConfig` supports speed/accuracy tradeoffs:
|
|
123
|
+
|
|
124
|
+
- `gradient_collection`: `"loop"` (default), `"auto"`, `"vmap"`.
|
|
125
|
+
- `subspace_method`: `"svd"`, `"randomized_svd"`, `"diag_topk"`.
|
|
126
|
+
- `vmap_chunk_size`: optional chunking for lower memory.
|
|
127
|
+
- `target_explained_variance`: auto-rank selection (default `0.9`).
|
|
128
|
+
|
|
129
|
+
Example:
|
|
130
|
+
|
|
131
|
+
```python
|
|
132
|
+
from alignment_risk import AlignmentRiskPipeline, PipelineConfig
|
|
133
|
+
|
|
134
|
+
config = PipelineConfig()
|
|
135
|
+
config.fisher.gradient_collection = "vmap"
|
|
136
|
+
config.fisher.subspace_method = "randomized_svd"
|
|
137
|
+
config.fisher.vmap_chunk_size = 16
|
|
138
|
+
```
|
|
139
|
+
|
|
140
|
+
## Theory and Math References
|
|
141
|
+
|
|
142
|
+
- **[AIC-2026]** Springer, Max, et al. (2026). *The Geometry of Alignment Collapse: When Fine-Tuning Breaks Safety*. arXiv:2602.15799v1. PDF: [https://arxiv.org/pdf/2602.15799](https://arxiv.org/pdf/2602.15799)
|
|
143
|
+
- **[ALIGNGUARD-2025]** Das, Amitava, et al. (2025). *AlignGuard-LoRA: Alignment-Preserving Fine-Tuning via Fisher-Guided Decomposition and Riemannian-Geodesic Collision Regularization*. arXiv:2508.02079v1. PDF: [https://arxiv.org/pdf/2508.02079](https://arxiv.org/pdf/2508.02079)
|
|
144
|
+
|
|
145
|
+
Detailed internal mappings and equations:
|
|
146
|
+
- `docs/SOURCES.md`
|
|
147
|
+
- `docs/MATH.md`
|
|
148
|
+
|
|
149
|
+
## Development
|
|
150
|
+
|
|
151
|
+
```bash
|
|
152
|
+
make install
|
|
153
|
+
make test
|
|
154
|
+
make lint
|
|
155
|
+
make typecheck
|
|
156
|
+
make build
|
|
157
|
+
make check-dist
|
|
158
|
+
```
|
|
159
|
+
|
|
160
|
+
See [CONTRIBUTING.md](CONTRIBUTING.md) for contribution workflow details.
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=69", "wheel"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "alignment-risk"
|
|
7
|
+
version = "0.1.3"
|
|
8
|
+
description = "Template toolkit for AIC-based alignment collapse diagnostics"
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
authors = [{name = "Mohammed Talat"}]
|
|
11
|
+
license = "MIT"
|
|
12
|
+
license-files = ["LICENSE"]
|
|
13
|
+
requires-python = ">=3.10,<3.13"
|
|
14
|
+
keywords = ["alignment", "llm", "safety", "fine-tuning", "risk", "fisher"]
|
|
15
|
+
classifiers = [
|
|
16
|
+
"Development Status :: 3 - Alpha",
|
|
17
|
+
"Intended Audience :: Developers",
|
|
18
|
+
"Intended Audience :: Science/Research",
|
|
19
|
+
"Operating System :: OS Independent",
|
|
20
|
+
"Programming Language :: Python :: 3",
|
|
21
|
+
"Programming Language :: Python :: 3.10",
|
|
22
|
+
"Programming Language :: Python :: 3.11",
|
|
23
|
+
"Programming Language :: Python :: 3.12",
|
|
24
|
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
25
|
+
]
|
|
26
|
+
dependencies = [
|
|
27
|
+
"torch>=2.3,<2.7",
|
|
28
|
+
"numpy>=1.24,<3.0",
|
|
29
|
+
"matplotlib>=3.8,<3.11",
|
|
30
|
+
]
|
|
31
|
+
|
|
32
|
+
[project.optional-dependencies]
|
|
33
|
+
dev = [
|
|
34
|
+
"build>=1.2,<2.0",
|
|
35
|
+
"twine>=6.0,<7.0",
|
|
36
|
+
"pytest>=8.0,<9.0",
|
|
37
|
+
"pytest-xdist>=3.6,<4.0",
|
|
38
|
+
"ruff>=0.8,<1.0",
|
|
39
|
+
"mypy>=1.11,<2.0",
|
|
40
|
+
"mkdocs>=1.6,<2.0",
|
|
41
|
+
"mkdocs-material>=9.6,<10.0",
|
|
42
|
+
"pymdown-extensions>=10.0,<11.0",
|
|
43
|
+
]
|
|
44
|
+
|
|
45
|
+
[project.scripts]
|
|
46
|
+
alignment-risk = "alignment_risk.cli:main"
|
|
47
|
+
|
|
48
|
+
[tool.setuptools]
|
|
49
|
+
package-dir = {"" = "src"}
|
|
50
|
+
|
|
51
|
+
[tool.setuptools.packages.find]
|
|
52
|
+
where = ["src"]
|
|
53
|
+
|
|
54
|
+
[tool.setuptools.package-data]
|
|
55
|
+
alignment_risk = ["py.typed"]
|
|
56
|
+
|
|
57
|
+
[tool.pytest.ini_options]
|
|
58
|
+
pythonpath = ["src"]
|
|
59
|
+
addopts = "-q"
|
|
60
|
+
|
|
61
|
+
[tool.ruff]
|
|
62
|
+
line-length = 100
|
|
63
|
+
target-version = "py310"
|
|
64
|
+
|
|
65
|
+
[tool.ruff.lint]
|
|
66
|
+
select = ["E", "F", "I"]
|
|
67
|
+
ignore = ["E501"]
|
|
68
|
+
|
|
69
|
+
[tool.mypy]
|
|
70
|
+
python_version = "3.10"
|
|
71
|
+
ignore_missing_imports = true
|
|
72
|
+
warn_unused_ignores = true
|
|
73
|
+
warn_redundant_casts = true
|
|
74
|
+
warn_return_any = false
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
from .__about__ import __version__
|
|
2
|
+
from .curvature import CurvatureConfig, CurvatureCouplingAnalyzer
|
|
3
|
+
from .fisher import FisherConfig, FisherSubspaceAnalyzer
|
|
4
|
+
from .forecast import ForecastConfig, forecast_stability
|
|
5
|
+
from .mitigation import (
|
|
6
|
+
AlignGuardConfig,
|
|
7
|
+
AlignGuardLoRARegularizer,
|
|
8
|
+
AlignGuardLossBreakdown,
|
|
9
|
+
decompose_update,
|
|
10
|
+
fisher_weighted_alignment_penalty,
|
|
11
|
+
geodesic_overlap_penalty,
|
|
12
|
+
project_onto_subspace,
|
|
13
|
+
riemannian_overlap_penalty,
|
|
14
|
+
)
|
|
15
|
+
from .pipeline import AlignmentRiskPipeline, PipelineConfig, PipelineMode
|
|
16
|
+
from .types import (
|
|
17
|
+
CurvatureCouplingResult,
|
|
18
|
+
InitialRiskScore,
|
|
19
|
+
ParameterSlice,
|
|
20
|
+
RiskAssessmentReport,
|
|
21
|
+
SafetyForecast,
|
|
22
|
+
SensitivitySubspace,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
# Public package API. ``from alignment_risk import *`` exports exactly these
# names; every entry is imported above from its defining submodule.
__all__ = [
    "__version__",
    "AlignmentRiskPipeline",
    "PipelineConfig",
    "PipelineMode",
    "AlignGuardConfig",
    "AlignGuardLoRARegularizer",
    "AlignGuardLossBreakdown",
    "project_onto_subspace",
    "decompose_update",
    "fisher_weighted_alignment_penalty",
    "riemannian_overlap_penalty",
    "geodesic_overlap_penalty",
    "FisherSubspaceAnalyzer",
    "FisherConfig",
    "CurvatureCouplingAnalyzer",
    "CurvatureConfig",
    "ForecastConfig",
    "forecast_stability",
    "SensitivitySubspace",
    "InitialRiskScore",
    "CurvatureCouplingResult",
    "SafetyForecast",
    "RiskAssessmentReport",
    "ParameterSlice",
]
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import argparse
|
|
4
|
+
from typing import Literal, cast
|
|
5
|
+
|
|
6
|
+
from .demo import run_demo
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def build_parser() -> argparse.ArgumentParser:
    """Construct the argument parser for the ``alignment-risk`` command.

    Returns:
        A parser with one subcommand, ``demo``, which accepts ``--output-dir``
        (default ``"artifacts"``) and ``--mode`` (``"full"`` or ``"lora"``,
        default ``"full"``). The selected subcommand name is stored on
        ``args.command``; it is ``None`` when no subcommand is given.
    """
    root = argparse.ArgumentParser(description="AIC alignment risk template")
    commands = root.add_subparsers(dest="command")

    demo_cmd = commands.add_parser("demo", help="run a synthetic end-to-end diagnostic")
    demo_cmd.add_argument(
        "--output-dir",
        default="artifacts",
        help="where to save plots",
    )
    demo_cmd.add_argument(
        "--mode",
        choices=["full", "lora"],
        default="full",
        help="analysis mode: full fine-tuning weights or LoRA adapter-only weights",
    )
    return root
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def main(argv: list[str] | None = None) -> None:
    """Entry point for the ``alignment-risk`` command line.

    Args:
        argv: Arguments excluding the program name. When ``None`` (the default,
            as used by the console-script entry point) argparse reads
            ``sys.argv[1:]``; passing an explicit list lets tests and other
            Python code drive the CLI without touching ``sys.argv``.

    Behavior:
        - no subcommand: print the help text and return;
        - ``demo``: call ``run_demo`` with the parsed ``--output-dir`` and ``--mode``;
        - anything else: ``parser.error`` (argparse prints usage and exits with status 2).
    """
    parser = build_parser()
    args = parser.parse_args(argv)

    if args.command is None:
        parser.print_help()
        return

    if args.command == "demo":
        # The demo subparser always defines both options (with defaults), and
        # argparse restricts --mode to ["full", "lora"], so the cast is safe.
        mode = cast(Literal["full", "lora"], args.mode)
        run_demo(output_dir=args.output_dir, mode=mode)
        return

    # Defensive only: argparse already rejects unknown subcommands in parse_args.
    parser.error(f"Unknown command: {args.command}")
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
if __name__ == "__main__":
    # Support ``python -m alignment_risk.cli``; the installed ``alignment-risk``
    # console script calls ``main`` directly (see ``[project.scripts]``).
    main()
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
"""Curvature coupling diagnostics for alignment risk.
|
|
2
|
+
|
|
3
|
+
Academic grounding:
|
|
4
|
+
- [AIC-2026] https://arxiv.org/pdf/2602.15799
|
|
5
|
+
|
|
6
|
+
See docs/SOURCES.md for section/page-level mapping to this module.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
from typing import Callable, Iterable
|
|
13
|
+
|
|
14
|
+
import torch
|
|
15
|
+
|
|
16
|
+
from .types import CurvatureCouplingResult, SensitivitySubspace
|
|
17
|
+
from .utils import flatten_tensors, move_to_device, named_trainable_parameters, resolve_device
|
|
18
|
+
|
|
19
|
+
# Caller-supplied loss: ``loss_fn(model, batch)`` must return a tensor.
# CurvatureCouplingAnalyzer mean-reduces non-scalar results before using them.
LossFn = Callable[[torch.nn.Module, object], torch.Tensor]
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass
class CurvatureConfig:
    """Settings for the curvature-coupling analyzer.

    Attributes:
        device: Device spec handed to ``resolve_device``. The default ``"auto"``
            presumably selects an available accelerator -- confirm in ``utils``.
        max_batches: Maximum number of dataloader batches averaged into the loss
            whose gradient and Hessian-vector product are measured.
        force_eval: When true, the model is switched to eval mode while the loss
            is computed, and its previous mode is restored afterwards.
    """

    # Field order is part of the positional constructor signature; keep it.
    device: str = "auto"
    max_batches: int = 1
    force_eval: bool = True
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class CurvatureCouplingAnalyzer:
    """Estimate AIC Condition 3 via the directional derivative ``H g``.

    The analyzer averages a loss over a few batches and takes its gradient ``g``
    with respect to the parameters listed in a sensitivity subspace. It then
    differentiates ``g`` along itself to get the Hessian-vector product
    ``a = H g``, reported as the "acceleration". Both ``g`` and ``a`` are
    measured inside the Fisher subspace.
    """

    def __init__(self, config: CurvatureConfig | None = None):
        self.config = config or CurvatureConfig()

    def analyze(
        self,
        model: torch.nn.Module,
        dataloader: Iterable[object],
        loss_fn: LossFn,
        subspace: SensitivitySubspace,
        *,
        max_batches_override: int | None = None,
    ) -> CurvatureCouplingResult:
        """Measure how the gradient and its curvature couple to the sensitive subspace.

        Args:
            model: Module to analyze. It is moved to the configured device with
                ``model.to``, which modifies an ``nn.Module`` in place.
            dataloader: Iterable of batches passed to ``loss_fn``.
            loss_fn: ``(model, batch) -> loss``; non-scalar losses are mean-reduced.
            subspace: Its ``parameter_slices`` select the parameters. Its
                ``fisher_eigenvectors`` (columns ``u_i``) and
                ``fisher_eigenvalues`` (``lambda_i``, clamped at 0) define the
                projection and weighting.
            max_batches_override: If given, used instead of ``config.max_batches``.

        Returns:
            A ``CurvatureCouplingResult`` containing:
              - ``epsilon_hat = sqrt(sum_i lambda_i * (u_i . g)^2)``;
              - ``gamma_hat = sqrt(sum_i lambda_i * (u_i . a)^2)``;
              - ``acceleration_norm = ||a||``;
              - ``projected_acceleration_norm = ||U U^T a||``.
            Projections are computed on CPU in the basis dtype.

        Raises:
            ValueError: if no trainable parameter matches the subspace, if the
                effective ``max_batches`` is < 1, or if the dataloader is empty.
        """
        device = resolve_device(self.config.device)
        model = model.to(device)
        max_batches = (
            max_batches_override if max_batches_override is not None else self.config.max_batches
        )

        selected_names = [p.name for p in subspace.parameter_slices]
        _, params = named_trainable_parameters(model, include_names=selected_names)
        if not params:
            raise ValueError("No trainable parameters found for curvature analysis.")

        # Snapshot every submodule's flag, not just the root's, so partially-frozen
        # models (e.g. a backbone deliberately left in eval mode) are restored exactly
        # rather than being switched wholesale into train mode by model.train().
        saved_modes = (
            [(module, module.training) for module in model.modules()]
            if self.config.force_eval
            else []
        )
        if self.config.force_eval:
            model.eval()
        try:
            loss = self._mean_loss(model, dataloader, loss_fn, device, max_batches=max_batches)
        finally:
            for module, training in saved_modes:
                module.training = training

        # create_graph=True keeps g differentiable so it can be differentiated again.
        grads = torch.autograd.grad(loss, params, create_graph=True, allow_unused=True)
        g = flatten_tensors(grads, params)

        # Directional derivative of the gradient field along itself:
        # d/dtheta [g(theta) . stop_grad(g)] = H g.
        dot = torch.dot(g, g.detach())
        if dot.requires_grad:
            hvp = torch.autograd.grad(dot, params, allow_unused=True)
        else:
            # g does not depend on the parameters (loss linear in them), so H g = 0;
            # autograd.grad would raise on an output that carries no graph.
            hvp = tuple(torch.zeros_like(p) for p in params)

        basis = subspace.fisher_eigenvectors.detach().cpu()
        eigvals = torch.clamp(subspace.fisher_eigenvalues.detach().cpu(), min=0.0).to(dtype=basis.dtype)

        g_vec = g.detach().to(device=basis.device, dtype=basis.dtype)
        a_vec = flatten_tensors(hvp, params).detach().to(device=basis.device, dtype=basis.dtype)

        g_coeff = basis.T @ g_vec
        a_coeff = basis.T @ a_vec

        return CurvatureCouplingResult(
            gamma_hat=self._fisher_weighted_norm(eigvals, a_coeff),
            epsilon_hat=self._fisher_weighted_norm(eigvals, g_coeff),
            acceleration_norm=float(a_vec.norm().item()),
            projected_acceleration_norm=float((basis @ a_coeff).norm().item()),
        )

    @staticmethod
    def _fisher_weighted_norm(eigvals: torch.Tensor, coeff: torch.Tensor) -> float:
        """Return ``sqrt(sum_i eigvals_i * coeff_i**2)`` as a Python float."""
        return float(torch.sqrt(torch.sum(eigvals * (coeff ** 2))).item())

    def _mean_loss(
        self,
        model: torch.nn.Module,
        dataloader: Iterable[object],
        loss_fn: LossFn,
        device: torch.device,
        *,
        max_batches: int,
    ) -> torch.Tensor:
        """Average ``loss_fn`` over at most ``max_batches`` batches, keeping the graph.

        Exactly ``min(max_batches, len(dataloader))`` batches are requested from
        the dataloader: iteration stops right after the last one used, so no
        extra batch is loaded or consumed.

        Raises:
            ValueError: if ``max_batches < 1`` or the dataloader yields no batches.
        """
        if max_batches < 1:
            raise ValueError(f"max_batches must be >= 1, got {max_batches}.")

        total: torch.Tensor | None = None
        count = 0
        for batch in dataloader:
            loss = loss_fn(model, move_to_device(batch, device))
            if loss.ndim > 0:
                loss = loss.mean()
            total = loss if total is None else total + loss
            count += 1
            # Break here, not at the top of the next iteration, so the dataloader
            # is never asked for a batch that would be thrown away.
            if count >= max_batches:
                break

        if total is None:
            raise ValueError("Fine-tuning dataloader yielded no batches.")

        return total / float(count)
|