alignment-risk 0.1.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. alignment_risk-0.1.3/LICENSE +21 -0
  2. alignment_risk-0.1.3/MANIFEST.in +6 -0
  3. alignment_risk-0.1.3/PKG-INFO +194 -0
  4. alignment_risk-0.1.3/README.md +160 -0
  5. alignment_risk-0.1.3/pyproject.toml +74 -0
  6. alignment_risk-0.1.3/setup.cfg +4 -0
  7. alignment_risk-0.1.3/src/alignment_risk/__about__.py +5 -0
  8. alignment_risk-0.1.3/src/alignment_risk/__init__.py +50 -0
  9. alignment_risk-0.1.3/src/alignment_risk/__main__.py +4 -0
  10. alignment_risk-0.1.3/src/alignment_risk/cli.py +43 -0
  11. alignment_risk-0.1.3/src/alignment_risk/curvature.py +115 -0
  12. alignment_risk-0.1.3/src/alignment_risk/demo.py +115 -0
  13. alignment_risk-0.1.3/src/alignment_risk/fisher.py +545 -0
  14. alignment_risk-0.1.3/src/alignment_risk/forecast.py +69 -0
  15. alignment_risk-0.1.3/src/alignment_risk/mitigation.py +222 -0
  16. alignment_risk-0.1.3/src/alignment_risk/orthogonality.py +57 -0
  17. alignment_risk-0.1.3/src/alignment_risk/pipeline.py +372 -0
  18. alignment_risk-0.1.3/src/alignment_risk/py.typed +0 -0
  19. alignment_risk-0.1.3/src/alignment_risk/types.py +63 -0
  20. alignment_risk-0.1.3/src/alignment_risk/utils.py +160 -0
  21. alignment_risk-0.1.3/src/alignment_risk/visualization.py +89 -0
  22. alignment_risk-0.1.3/src/alignment_risk.egg-info/PKG-INFO +194 -0
  23. alignment_risk-0.1.3/src/alignment_risk.egg-info/SOURCES.txt +36 -0
  24. alignment_risk-0.1.3/src/alignment_risk.egg-info/dependency_links.txt +1 -0
  25. alignment_risk-0.1.3/src/alignment_risk.egg-info/entry_points.txt +2 -0
  26. alignment_risk-0.1.3/src/alignment_risk.egg-info/requires.txt +14 -0
  27. alignment_risk-0.1.3/src/alignment_risk.egg-info/top_level.txt +1 -0
  28. alignment_risk-0.1.3/tests/test_cli.py +24 -0
  29. alignment_risk-0.1.3/tests/test_curvature.py +71 -0
  30. alignment_risk-0.1.3/tests/test_device.py +17 -0
  31. alignment_risk-0.1.3/tests/test_fisher_options.py +230 -0
  32. alignment_risk-0.1.3/tests/test_fisher_precision.py +44 -0
  33. alignment_risk-0.1.3/tests/test_forecast.py +45 -0
  34. alignment_risk-0.1.3/tests/test_mitigation.py +112 -0
  35. alignment_risk-0.1.3/tests/test_modes.py +56 -0
  36. alignment_risk-0.1.3/tests/test_orthogonality.py +38 -0
  37. alignment_risk-0.1.3/tests/test_pipeline_run.py +434 -0
  38. alignment_risk-0.1.3/tests/test_utils_batch.py +60 -0
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Mohammed Talat
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,6 @@
1
+ include README.md
2
+ include LICENSE
3
+ recursive-include src/alignment_risk *.py
4
+ include src/alignment_risk/py.typed
5
+ global-exclude __pycache__
6
+ global-exclude *.py[cod]
@@ -0,0 +1,194 @@
1
+ Metadata-Version: 2.4
2
+ Name: alignment-risk
3
+ Version: 0.1.3
4
+ Summary: Template toolkit for AIC-based alignment collapse diagnostics
5
+ Author: Mohammed Talat
6
+ License-Expression: MIT
7
+ Keywords: alignment,llm,safety,fine-tuning,risk,fisher
8
+ Classifier: Development Status :: 3 - Alpha
9
+ Classifier: Intended Audience :: Developers
10
+ Classifier: Intended Audience :: Science/Research
11
+ Classifier: Operating System :: OS Independent
12
+ Classifier: Programming Language :: Python :: 3
13
+ Classifier: Programming Language :: Python :: 3.10
14
+ Classifier: Programming Language :: Python :: 3.11
15
+ Classifier: Programming Language :: Python :: 3.12
16
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
17
+ Requires-Python: <3.13,>=3.10
18
+ Description-Content-Type: text/markdown
19
+ License-File: LICENSE
20
+ Requires-Dist: torch<2.7,>=2.3
21
+ Requires-Dist: numpy<3.0,>=1.24
22
+ Requires-Dist: matplotlib<3.11,>=3.8
23
+ Provides-Extra: dev
24
+ Requires-Dist: build<2.0,>=1.2; extra == "dev"
25
+ Requires-Dist: twine<7.0,>=6.0; extra == "dev"
26
+ Requires-Dist: pytest<9.0,>=8.0; extra == "dev"
27
+ Requires-Dist: pytest-xdist<4.0,>=3.6; extra == "dev"
28
+ Requires-Dist: ruff<1.0,>=0.8; extra == "dev"
29
+ Requires-Dist: mypy<2.0,>=1.11; extra == "dev"
30
+ Requires-Dist: mkdocs<2.0,>=1.6; extra == "dev"
31
+ Requires-Dist: mkdocs-material<10.0,>=9.6; extra == "dev"
32
+ Requires-Dist: pymdown-extensions<11.0,>=10.0; extra == "dev"
33
+ Dynamic: license-file
34
+
35
+ # alignment-risk
36
+
37
+ `alignment-risk` is a Python package for pre-flight alignment risk diagnostics during fine-tuning.
38
+ It estimates whether updates are likely to drift into safety-sensitive directions before you commit to a run.
39
+
40
+ Full documentation site: [sirhan1.github.io/modelFineTuneRiskAssessment](https://sirhan1.github.io/modelFineTuneRiskAssessment/)
41
+
42
+ ## Quick Start (1 minute)
43
+
44
+ Install:
45
+
46
+ ```bash
47
+ pip install alignment-risk
48
+ ```
49
+
50
+ Run the built-in demo:
51
+
52
+ ```bash
53
+ alignment-risk demo --output-dir artifacts
54
+ ```
55
+
56
+ Outputs:
57
+ - `artifacts/sensitivity_map.png`
58
+ - `artifacts/safety_decay_forecast.png`
59
+
60
+ ## Installation
61
+
62
+ PyPI:
63
+
64
+ ```bash
65
+ pip install alignment-risk
66
+ ```
67
+
68
+ Local development (Apple Silicon convenience):
69
+
70
+ ```bash
71
+ make setup
72
+ source .venv/bin/activate
73
+ ```
74
+
75
+ Local development (manual, cross-platform):
76
+
77
+ ```bash
78
+ python -m venv .venv
79
+ source .venv/bin/activate
80
+ pip install -e ".[dev]"
81
+ ```
82
+
83
+ ## Basic Usage (Python API)
84
+
85
+ ```python
86
+ from alignment_risk import AlignmentRiskPipeline, PipelineConfig
87
+
88
+ config = PipelineConfig(mode="lora") # "full" or "lora"
89
+ pipeline = AlignmentRiskPipeline(config)
90
+
91
+ report = pipeline.run(
92
+ model=model,
93
+ safety_dataloader=safety_loader,
94
+ safety_loss_fn=safety_loss_fn,
95
+ fine_tune_dataloader=ft_loader,
96
+ fine_tune_loss_fn=ft_loss_fn,
97
+ )
98
+
99
+ print(report.warning)
100
+ print(report.forecast.collapse_step)
101
+ ```
102
+
103
+ ## CLI
104
+
105
+ ```bash
106
+ alignment-risk --help
107
+ alignment-risk demo --output-dir artifacts
108
+ alignment-risk demo --mode lora --output-dir artifacts
109
+ ```
110
+
111
+ ## What It Computes
112
+
113
+ 1. Low-rank safety sensitivity subspace from empirical Fisher geometry.
114
+ 2. Initial overlap risk (projection of first update into sensitive subspace).
115
+ 3. Curvature coupling risk (second-order drift signal).
116
+ 4. Quartic-style stability forecast and collapse-step estimate.
117
+
118
+ ## Modes
119
+
120
+ - `full`: analyze all selected trainable parameters.
121
+ - `lora`: analyze only trainable LoRA adapter parameters (`lora_`, `lora_A`, `lora_B` by default).
122
+
123
+ ## LoRA Mitigation (AlignGuard-style)
124
+
125
+ After `mode="lora"` risk analysis, attach a regularizer to penalize drift in sensitive directions:
126
+
127
+ ```python
128
+ from alignment_risk import AlignmentRiskPipeline, PipelineConfig, AlignGuardConfig
129
+
130
+ config = PipelineConfig(mode="lora")
131
+ pipeline = AlignmentRiskPipeline(config)
132
+
133
+ report = pipeline.run(
134
+ model=model,
135
+ safety_dataloader=safety_loader,
136
+ safety_loss_fn=safety_loss_fn,
137
+ fine_tune_dataloader=ft_loader,
138
+ fine_tune_loss_fn=ft_loss_fn,
139
+ )
140
+
141
+ mitigator = pipeline.build_lora_mitigator(
142
+ model,
143
+ report.subspace,
144
+ config=AlignGuardConfig(lambda_a=0.25, lambda_t=0.5, lambda_nc=0.1, alpha=0.5),
145
+ )
146
+
147
+ task_loss = ft_loss_fn(model, batch)
148
+ breakdown = mitigator.regularized_loss(task_loss)
149
+ breakdown.total_loss.backward()
150
+ ```
151
+
152
+ Use `mitigator.reset_reference()` to re-anchor regularization at the current adapter state.
153
+
154
+ ## Performance / Accuracy Controls
155
+
156
+ `FisherConfig` supports speed/accuracy tradeoffs:
157
+
158
+ - `gradient_collection`: `"loop"` (default), `"auto"`, `"vmap"`.
159
+ - `subspace_method`: `"svd"`, `"randomized_svd"`, `"diag_topk"`.
160
+ - `vmap_chunk_size`: optional chunking for lower memory.
161
+ - `target_explained_variance`: auto-rank selection (default `0.9`).
162
+
163
+ Example:
164
+
165
+ ```python
166
+ from alignment_risk import AlignmentRiskPipeline, PipelineConfig
167
+
168
+ config = PipelineConfig()
169
+ config.fisher.gradient_collection = "vmap"
170
+ config.fisher.subspace_method = "randomized_svd"
171
+ config.fisher.vmap_chunk_size = 16
172
+ ```
173
+
174
+ ## Theory and Math References
175
+
176
+ - **[AIC-2026]** Springer, Max, et al. (2026). *The Geometry of Alignment Collapse: When Fine-Tuning Breaks Safety*. arXiv:2602.15799v1. PDF: [https://arxiv.org/pdf/2602.15799](https://arxiv.org/pdf/2602.15799)
177
+ - **[ALIGNGUARD-2025]** Das, Amitava, et al. (2025). *AlignGuard-LoRA: Alignment-Preserving Fine-Tuning via Fisher-Guided Decomposition and Riemannian-Geodesic Collision Regularization*. arXiv:2508.02079v1. PDF: [https://arxiv.org/pdf/2508.02079](https://arxiv.org/pdf/2508.02079)
178
+
179
+ Detailed internal mappings and equations:
180
+ - `docs/SOURCES.md`
181
+ - `docs/MATH.md`
182
+
183
+ ## Development
184
+
185
+ ```bash
186
+ make install
187
+ make test
188
+ make lint
189
+ make typecheck
190
+ make build
191
+ make check-dist
192
+ ```
193
+
194
+ See [CONTRIBUTING.md](CONTRIBUTING.md) for contribution workflow details.
@@ -0,0 +1,160 @@
1
+ # alignment-risk
2
+
3
+ `alignment-risk` is a Python package for pre-flight alignment risk diagnostics during fine-tuning.
4
+ It estimates whether updates are likely to drift into safety-sensitive directions before you commit to a run.
5
+
6
+ Full documentation site: [sirhan1.github.io/modelFineTuneRiskAssessment](https://sirhan1.github.io/modelFineTuneRiskAssessment/)
7
+
8
+ ## Quick Start (1 minute)
9
+
10
+ Install:
11
+
12
+ ```bash
13
+ pip install alignment-risk
14
+ ```
15
+
16
+ Run the built-in demo:
17
+
18
+ ```bash
19
+ alignment-risk demo --output-dir artifacts
20
+ ```
21
+
22
+ Outputs:
23
+ - `artifacts/sensitivity_map.png`
24
+ - `artifacts/safety_decay_forecast.png`
25
+
26
+ ## Installation
27
+
28
+ PyPI:
29
+
30
+ ```bash
31
+ pip install alignment-risk
32
+ ```
33
+
34
+ Local development (Apple Silicon convenience):
35
+
36
+ ```bash
37
+ make setup
38
+ source .venv/bin/activate
39
+ ```
40
+
41
+ Local development (manual, cross-platform):
42
+
43
+ ```bash
44
+ python -m venv .venv
45
+ source .venv/bin/activate
46
+ pip install -e ".[dev]"
47
+ ```
48
+
49
+ ## Basic Usage (Python API)
50
+
51
+ ```python
52
+ from alignment_risk import AlignmentRiskPipeline, PipelineConfig
53
+
54
+ config = PipelineConfig(mode="lora") # "full" or "lora"
55
+ pipeline = AlignmentRiskPipeline(config)
56
+
57
+ report = pipeline.run(
58
+ model=model,
59
+ safety_dataloader=safety_loader,
60
+ safety_loss_fn=safety_loss_fn,
61
+ fine_tune_dataloader=ft_loader,
62
+ fine_tune_loss_fn=ft_loss_fn,
63
+ )
64
+
65
+ print(report.warning)
66
+ print(report.forecast.collapse_step)
67
+ ```
68
+
69
+ ## CLI
70
+
71
+ ```bash
72
+ alignment-risk --help
73
+ alignment-risk demo --output-dir artifacts
74
+ alignment-risk demo --mode lora --output-dir artifacts
75
+ ```
76
+
77
+ ## What It Computes
78
+
79
+ 1. Low-rank safety sensitivity subspace from empirical Fisher geometry.
80
+ 2. Initial overlap risk (projection of first update into sensitive subspace).
81
+ 3. Curvature coupling risk (second-order drift signal).
82
+ 4. Quartic-style stability forecast and collapse-step estimate.
83
+
84
+ ## Modes
85
+
86
+ - `full`: analyze all selected trainable parameters.
87
+ - `lora`: analyze only trainable LoRA adapter parameters (`lora_`, `lora_A`, `lora_B` by default).
88
+
89
+ ## LoRA Mitigation (AlignGuard-style)
90
+
91
+ After `mode="lora"` risk analysis, attach a regularizer to penalize drift in sensitive directions:
92
+
93
+ ```python
94
+ from alignment_risk import AlignmentRiskPipeline, PipelineConfig, AlignGuardConfig
95
+
96
+ config = PipelineConfig(mode="lora")
97
+ pipeline = AlignmentRiskPipeline(config)
98
+
99
+ report = pipeline.run(
100
+ model=model,
101
+ safety_dataloader=safety_loader,
102
+ safety_loss_fn=safety_loss_fn,
103
+ fine_tune_dataloader=ft_loader,
104
+ fine_tune_loss_fn=ft_loss_fn,
105
+ )
106
+
107
+ mitigator = pipeline.build_lora_mitigator(
108
+ model,
109
+ report.subspace,
110
+ config=AlignGuardConfig(lambda_a=0.25, lambda_t=0.5, lambda_nc=0.1, alpha=0.5),
111
+ )
112
+
113
+ task_loss = ft_loss_fn(model, batch)
114
+ breakdown = mitigator.regularized_loss(task_loss)
115
+ breakdown.total_loss.backward()
116
+ ```
117
+
118
+ Use `mitigator.reset_reference()` to re-anchor regularization at the current adapter state.
119
+
120
+ ## Performance / Accuracy Controls
121
+
122
+ `FisherConfig` supports speed/accuracy tradeoffs:
123
+
124
+ - `gradient_collection`: `"loop"` (default), `"auto"`, `"vmap"`.
125
+ - `subspace_method`: `"svd"`, `"randomized_svd"`, `"diag_topk"`.
126
+ - `vmap_chunk_size`: optional chunking for lower memory.
127
+ - `target_explained_variance`: auto-rank selection (default `0.9`).
128
+
129
+ Example:
130
+
131
+ ```python
132
+ from alignment_risk import AlignmentRiskPipeline, PipelineConfig
133
+
134
+ config = PipelineConfig()
135
+ config.fisher.gradient_collection = "vmap"
136
+ config.fisher.subspace_method = "randomized_svd"
137
+ config.fisher.vmap_chunk_size = 16
138
+ ```
139
+
140
+ ## Theory and Math References
141
+
142
+ - **[AIC-2026]** Springer, Max, et al. (2026). *The Geometry of Alignment Collapse: When Fine-Tuning Breaks Safety*. arXiv:2602.15799v1. PDF: [https://arxiv.org/pdf/2602.15799](https://arxiv.org/pdf/2602.15799)
143
+ - **[ALIGNGUARD-2025]** Das, Amitava, et al. (2025). *AlignGuard-LoRA: Alignment-Preserving Fine-Tuning via Fisher-Guided Decomposition and Riemannian-Geodesic Collision Regularization*. arXiv:2508.02079v1. PDF: [https://arxiv.org/pdf/2508.02079](https://arxiv.org/pdf/2508.02079)
144
+
145
+ Detailed internal mappings and equations:
146
+ - `docs/SOURCES.md`
147
+ - `docs/MATH.md`
148
+
149
+ ## Development
150
+
151
+ ```bash
152
+ make install
153
+ make test
154
+ make lint
155
+ make typecheck
156
+ make build
157
+ make check-dist
158
+ ```
159
+
160
+ See [CONTRIBUTING.md](CONTRIBUTING.md) for contribution workflow details.
@@ -0,0 +1,74 @@
1
+ [build-system]
2
+ requires = ["setuptools>=69", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "alignment-risk"
7
+ version = "0.1.3"
8
+ description = "Template toolkit for AIC-based alignment collapse diagnostics"
9
+ readme = "README.md"
10
+ authors = [{name = "Mohammed Talat"}]
11
+ license = "MIT"
12
+ license-files = ["LICENSE"]
13
+ requires-python = ">=3.10,<3.13"
14
+ keywords = ["alignment", "llm", "safety", "fine-tuning", "risk", "fisher"]
15
+ classifiers = [
16
+ "Development Status :: 3 - Alpha",
17
+ "Intended Audience :: Developers",
18
+ "Intended Audience :: Science/Research",
19
+ "Operating System :: OS Independent",
20
+ "Programming Language :: Python :: 3",
21
+ "Programming Language :: Python :: 3.10",
22
+ "Programming Language :: Python :: 3.11",
23
+ "Programming Language :: Python :: 3.12",
24
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
25
+ ]
26
+ dependencies = [
27
+ "torch>=2.3,<2.7",
28
+ "numpy>=1.24,<3.0",
29
+ "matplotlib>=3.8,<3.11",
30
+ ]
31
+
32
+ [project.optional-dependencies]
33
+ dev = [
34
+ "build>=1.2,<2.0",
35
+ "twine>=6.0,<7.0",
36
+ "pytest>=8.0,<9.0",
37
+ "pytest-xdist>=3.6,<4.0",
38
+ "ruff>=0.8,<1.0",
39
+ "mypy>=1.11,<2.0",
40
+ "mkdocs>=1.6,<2.0",
41
+ "mkdocs-material>=9.6,<10.0",
42
+ "pymdown-extensions>=10.0,<11.0",
43
+ ]
44
+
45
+ [project.scripts]
46
+ alignment-risk = "alignment_risk.cli:main"
47
+
48
+ [tool.setuptools]
49
+ package-dir = {"" = "src"}
50
+
51
+ [tool.setuptools.packages.find]
52
+ where = ["src"]
53
+
54
+ [tool.setuptools.package-data]
55
+ alignment_risk = ["py.typed"]
56
+
57
+ [tool.pytest.ini_options]
58
+ pythonpath = ["src"]
59
+ addopts = "-q"
60
+
61
+ [tool.ruff]
62
+ line-length = 100
63
+ target-version = "py310"
64
+
65
+ [tool.ruff.lint]
66
+ select = ["E", "F", "I"]
67
+ ignore = ["E501"]
68
+
69
+ [tool.mypy]
70
+ python_version = "3.10"
71
+ ignore_missing_imports = true
72
+ warn_unused_ignores = true
73
+ warn_redundant_casts = true
74
+ warn_return_any = false
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,5 @@
1
+ """Package metadata for alignment_risk."""
2
+
3
+ __all__ = ["__version__"]
4
+
5
+ __version__ = "0.1.3"
@@ -0,0 +1,50 @@
1
+ from .__about__ import __version__
2
+ from .curvature import CurvatureConfig, CurvatureCouplingAnalyzer
3
+ from .fisher import FisherConfig, FisherSubspaceAnalyzer
4
+ from .forecast import ForecastConfig, forecast_stability
5
+ from .mitigation import (
6
+ AlignGuardConfig,
7
+ AlignGuardLoRARegularizer,
8
+ AlignGuardLossBreakdown,
9
+ decompose_update,
10
+ fisher_weighted_alignment_penalty,
11
+ geodesic_overlap_penalty,
12
+ project_onto_subspace,
13
+ riemannian_overlap_penalty,
14
+ )
15
+ from .pipeline import AlignmentRiskPipeline, PipelineConfig, PipelineMode
16
+ from .types import (
17
+ CurvatureCouplingResult,
18
+ InitialRiskScore,
19
+ ParameterSlice,
20
+ RiskAssessmentReport,
21
+ SafetyForecast,
22
+ SensitivitySubspace,
23
+ )
24
+
25
+ __all__ = [
26
+ "__version__",
27
+ "AlignmentRiskPipeline",
28
+ "PipelineConfig",
29
+ "PipelineMode",
30
+ "AlignGuardConfig",
31
+ "AlignGuardLoRARegularizer",
32
+ "AlignGuardLossBreakdown",
33
+ "project_onto_subspace",
34
+ "decompose_update",
35
+ "fisher_weighted_alignment_penalty",
36
+ "riemannian_overlap_penalty",
37
+ "geodesic_overlap_penalty",
38
+ "FisherSubspaceAnalyzer",
39
+ "FisherConfig",
40
+ "CurvatureCouplingAnalyzer",
41
+ "CurvatureConfig",
42
+ "ForecastConfig",
43
+ "forecast_stability",
44
+ "SensitivitySubspace",
45
+ "InitialRiskScore",
46
+ "CurvatureCouplingResult",
47
+ "SafetyForecast",
48
+ "RiskAssessmentReport",
49
+ "ParameterSlice",
50
+ ]
@@ -0,0 +1,4 @@
1
+ from .cli import main
2
+
3
+ if __name__ == "__main__":
4
+ main()
@@ -0,0 +1,43 @@
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ from typing import Literal, cast
5
+
6
+ from .demo import run_demo
7
+
8
+
9
def build_parser() -> argparse.ArgumentParser:
    """Build the argument parser for the ``alignment-risk`` command line.

    The parser exposes a single ``demo`` subcommand with two options:
    ``--output-dir`` (default ``"artifacts"``) and ``--mode`` (``"full"`` or
    ``"lora"``, default ``"full"``). The chosen subcommand name is stored on
    the parsed namespace as ``command`` and is ``None`` when no subcommand is
    given.

    Returns:
        The configured top-level parser.
    """
    # Pin the program name so usage/error text reads "alignment-risk ..." even
    # when invoked as ``python -m alignment_risk`` (argparse would otherwise
    # derive the name from sys.argv[0], e.g. "__main__.py").
    parser = argparse.ArgumentParser(
        prog="alignment-risk",
        description="AIC alignment risk template",
    )
    sub = parser.add_subparsers(dest="command")

    demo = sub.add_parser("demo", help="run a synthetic end-to-end diagnostic")
    demo.add_argument("--output-dir", default="artifacts", help="where to save plots")
    demo.add_argument(
        "--mode",
        default="full",
        choices=["full", "lora"],
        help="analysis mode: full fine-tuning weights or LoRA adapter-only weights",
    )

    return parser
23
+
24
+
25
def main(argv: list[str] | None = None) -> None:
    """Run the ``alignment-risk`` command-line interface.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what the
            console-script entry point and ``python -m`` invocation rely on.
            Passing an explicit list allows the CLI to be driven from code.
    """
    parser = build_parser()
    args = parser.parse_args(argv)

    if args.command is None:
        # No subcommand supplied: show usage rather than failing.
        parser.print_help()
        return

    if args.command == "demo":
        # argparse supplies defaults for both options and restricts --mode to
        # the declared choices, so the cast only narrows the type for checkers.
        mode = cast(Literal["full", "lora"], args.mode)
        run_demo(output_dir=args.output_dir, mode=mode)
        return

    # Defensive fallback for a subcommand registered without a handler here.
    parser.error(f"Unknown command: {args.command}")
40
+
41
+
42
+ if __name__ == "__main__":
43
+ main()
@@ -0,0 +1,115 @@
1
+ """Curvature coupling diagnostics for alignment risk.
2
+
3
+ Academic grounding:
4
+ - [AIC-2026] https://arxiv.org/pdf/2602.15799
5
+
6
+ See docs/SOURCES.md for section/page-level mapping to this module.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from dataclasses import dataclass
12
+ from typing import Callable, Iterable
13
+
14
+ import torch
15
+
16
+ from .types import CurvatureCouplingResult, SensitivitySubspace
17
+ from .utils import flatten_tensors, move_to_device, named_trainable_parameters, resolve_device
18
+
19
+ LossFn = Callable[[torch.nn.Module, object], torch.Tensor]
20
+
21
+
22
@dataclass
class CurvatureConfig:
    """Options controlling how ``CurvatureCouplingAnalyzer`` evaluates the loss."""

    # Device specification passed to ``resolve_device`` before the model is moved.
    device: str = "auto"
    # Default number of batches averaged into the loss; callers of ``analyze``
    # can override it per call via ``max_batches_override``.
    max_batches: int = 1
    # When True, the model is switched to eval mode while the loss is computed
    # and returned to train mode afterwards if it was training before.
    force_eval: bool = True
27
+
28
+
29
class CurvatureCouplingAnalyzer:
    """Estimate AIC Condition 3 via directional derivative H g.

    The analyzer averages a loss over a few batches, takes its gradient ``g``
    with respect to the parameters named in a sensitivity subspace, forms the
    Hessian-vector product ``H g`` by differentiating ``g . stop_grad(g)``, and
    measures how strongly both vectors load onto the Fisher eigenbasis.
    """

    def __init__(self, config: CurvatureConfig | None = None):
        """Store ``config``, falling back to a default ``CurvatureConfig``."""
        self.config = config or CurvatureConfig()

    def analyze(
        self,
        model: torch.nn.Module,
        dataloader: Iterable[object],
        loss_fn: LossFn,
        subspace: SensitivitySubspace,
        *,
        max_batches_override: int | None = None,
    ) -> CurvatureCouplingResult:
        """Compute curvature-coupling statistics for ``model`` on ``dataloader``.

        Args:
            model: Module to analyze. It is moved to the configured device.
            dataloader: Iterable of batches fed to ``loss_fn``.
            loss_fn: Callable ``(model, batch) -> loss tensor``.
            subspace: Sensitivity subspace providing the parameter names to
                differentiate against plus the Fisher eigenvectors/eigenvalues.
            max_batches_override: Per-call replacement for
                ``config.max_batches``; ``None`` keeps the configured value.

        Returns:
            A ``CurvatureCouplingResult`` with:
            ``epsilon_hat`` = sqrt(sum_i lambda_i * <v_i, g>^2),
            ``gamma_hat`` = sqrt(sum_i lambda_i * <v_i, H g>^2),
            ``acceleration_norm`` = ||H g|| and
            ``projected_acceleration_norm`` = ||V V^T H g||, where ``v_i`` are
            the columns of the eigenvector basis ``V`` and ``lambda_i`` the
            eigenvalues clamped to be non-negative.

        Raises:
            ValueError: If no matching trainable parameters exist, if the
                batch limit is below 1, or if the dataloader yields no batches.
        """
        device = resolve_device(self.config.device)
        model = model.to(device)
        max_batches = (
            max_batches_override if max_batches_override is not None else self.config.max_batches
        )

        selected_names = [p.name for p in subspace.parameter_slices]
        _, params = named_trainable_parameters(model, include_names=selected_names)
        if not params:
            raise ValueError("No trainable parameters found for curvature analysis.")

        was_training = model.training
        if self.config.force_eval:
            model.eval()
        try:
            loss = self._mean_loss(model, dataloader, loss_fn, device, max_batches=max_batches)
        finally:
            # Restore the caller's training mode even if the loss computation fails.
            if self.config.force_eval and was_training:
                model.train()

        # create_graph=True keeps the gradient differentiable for the second pass.
        grads = torch.autograd.grad(loss, params, create_graph=True, allow_unused=True)
        g = flatten_tensors(grads, params)

        # Directional derivative of the gradient field along itself: (∇g)g = H g.
        dot = torch.dot(g, g.detach())
        if dot.requires_grad:
            hvp = torch.autograd.grad(dot, params, allow_unused=True)
        else:
            # The gradient does not depend on the parameters (loss is linear in
            # them), so H g is exactly zero; autograd.grad would raise here.
            hvp = tuple(torch.zeros_like(p) for p in params)

        # Projections are done on CPU in the basis dtype.
        basis = subspace.fisher_eigenvectors.detach().cpu()
        eigvals = torch.clamp(subspace.fisher_eigenvalues.detach().cpu(), min=0.0).to(
            dtype=basis.dtype
        )

        g_vec = g.detach().to(device=basis.device, dtype=basis.dtype)
        a_vec = flatten_tensors(hvp, params).detach().to(device=basis.device, dtype=basis.dtype)

        g_coeff = basis.T @ g_vec
        epsilon_hat = float(torch.sqrt(torch.sum(eigvals * (g_coeff ** 2))).item())

        a_coeff = basis.T @ a_vec
        projected_a = basis @ a_coeff

        gamma_hat = float(torch.sqrt(torch.sum(eigvals * (a_coeff ** 2))).item())
        projected_acc_norm = float(projected_a.norm().item())

        return CurvatureCouplingResult(
            gamma_hat=gamma_hat,
            epsilon_hat=epsilon_hat,
            acceleration_norm=float(a_vec.norm().item()),
            projected_acceleration_norm=projected_acc_norm,
        )

    def _mean_loss(
        self,
        model: torch.nn.Module,
        dataloader: Iterable[object],
        loss_fn: LossFn,
        device: torch.device,
        *,
        max_batches: int,
    ) -> torch.Tensor:
        """Return the loss averaged over at most ``max_batches`` batches.

        Each batch is moved to ``device`` and passed to ``loss_fn``; non-scalar
        losses are reduced with ``mean()`` before being accumulated.

        Raises:
            ValueError: If ``max_batches`` is below 1 or the dataloader yields
                no batches.
        """
        if max_batches < 1:
            raise ValueError(
                f"max_batches must be >= 1 for curvature analysis, got {max_batches}."
            )

        total = None
        count = 0
        # range() is listed first so zip stops *before* pulling another batch
        # once the limit is reached; no batch is fetched only to be discarded.
        for _, batch in zip(range(max_batches), dataloader):
            batch = move_to_device(batch, device)
            loss = loss_fn(model, batch)
            if loss.ndim > 0:
                loss = loss.mean()
            total = loss if total is None else total + loss
            count += 1

        if count == 0 or total is None:
            raise ValueError("Fine-tuning dataloader yielded no batches.")

        return total / float(count)