invarlock 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132) hide show
  1. invarlock/__init__.py +33 -0
  2. invarlock/__main__.py +10 -0
  3. invarlock/_data/runtime/profiles/ci_cpu.yaml +15 -0
  4. invarlock/_data/runtime/profiles/release.yaml +23 -0
  5. invarlock/_data/runtime/tiers.yaml +76 -0
  6. invarlock/adapters/__init__.py +102 -0
  7. invarlock/adapters/_capabilities.py +45 -0
  8. invarlock/adapters/auto.py +99 -0
  9. invarlock/adapters/base.py +530 -0
  10. invarlock/adapters/base_types.py +85 -0
  11. invarlock/adapters/hf_bert.py +852 -0
  12. invarlock/adapters/hf_gpt2.py +403 -0
  13. invarlock/adapters/hf_llama.py +485 -0
  14. invarlock/adapters/hf_mixin.py +383 -0
  15. invarlock/adapters/hf_onnx.py +112 -0
  16. invarlock/adapters/hf_t5.py +137 -0
  17. invarlock/adapters/py.typed +1 -0
  18. invarlock/assurance/__init__.py +43 -0
  19. invarlock/cli/__init__.py +8 -0
  20. invarlock/cli/__main__.py +8 -0
  21. invarlock/cli/_evidence.py +25 -0
  22. invarlock/cli/_json.py +75 -0
  23. invarlock/cli/adapter_auto.py +162 -0
  24. invarlock/cli/app.py +287 -0
  25. invarlock/cli/commands/__init__.py +26 -0
  26. invarlock/cli/commands/certify.py +403 -0
  27. invarlock/cli/commands/doctor.py +1358 -0
  28. invarlock/cli/commands/explain_gates.py +151 -0
  29. invarlock/cli/commands/export_html.py +100 -0
  30. invarlock/cli/commands/plugins.py +1331 -0
  31. invarlock/cli/commands/report.py +354 -0
  32. invarlock/cli/commands/run.py +4146 -0
  33. invarlock/cli/commands/verify.py +1040 -0
  34. invarlock/cli/config.py +396 -0
  35. invarlock/cli/constants.py +68 -0
  36. invarlock/cli/device.py +92 -0
  37. invarlock/cli/doctor_helpers.py +74 -0
  38. invarlock/cli/errors.py +6 -0
  39. invarlock/cli/overhead_utils.py +60 -0
  40. invarlock/cli/provenance.py +66 -0
  41. invarlock/cli/utils.py +41 -0
  42. invarlock/config.py +56 -0
  43. invarlock/core/__init__.py +62 -0
  44. invarlock/core/abi.py +15 -0
  45. invarlock/core/api.py +274 -0
  46. invarlock/core/auto_tuning.py +317 -0
  47. invarlock/core/bootstrap.py +226 -0
  48. invarlock/core/checkpoint.py +221 -0
  49. invarlock/core/contracts.py +73 -0
  50. invarlock/core/error_utils.py +64 -0
  51. invarlock/core/events.py +298 -0
  52. invarlock/core/exceptions.py +95 -0
  53. invarlock/core/registry.py +481 -0
  54. invarlock/core/retry.py +146 -0
  55. invarlock/core/runner.py +2041 -0
  56. invarlock/core/types.py +154 -0
  57. invarlock/edits/__init__.py +12 -0
  58. invarlock/edits/_edit_utils.py +249 -0
  59. invarlock/edits/_external_utils.py +268 -0
  60. invarlock/edits/noop.py +47 -0
  61. invarlock/edits/py.typed +1 -0
  62. invarlock/edits/quant_rtn.py +801 -0
  63. invarlock/edits/registry.py +166 -0
  64. invarlock/eval/__init__.py +23 -0
  65. invarlock/eval/bench.py +1207 -0
  66. invarlock/eval/bootstrap.py +50 -0
  67. invarlock/eval/data.py +2052 -0
  68. invarlock/eval/metrics.py +2167 -0
  69. invarlock/eval/primary_metric.py +767 -0
  70. invarlock/eval/probes/__init__.py +24 -0
  71. invarlock/eval/probes/fft.py +139 -0
  72. invarlock/eval/probes/mi.py +213 -0
  73. invarlock/eval/probes/post_attention.py +323 -0
  74. invarlock/eval/providers/base.py +67 -0
  75. invarlock/eval/providers/seq2seq.py +111 -0
  76. invarlock/eval/providers/text_lm.py +113 -0
  77. invarlock/eval/providers/vision_text.py +93 -0
  78. invarlock/eval/py.typed +1 -0
  79. invarlock/guards/__init__.py +18 -0
  80. invarlock/guards/_contracts.py +9 -0
  81. invarlock/guards/invariants.py +640 -0
  82. invarlock/guards/policies.py +805 -0
  83. invarlock/guards/py.typed +1 -0
  84. invarlock/guards/rmt.py +2097 -0
  85. invarlock/guards/spectral.py +1419 -0
  86. invarlock/guards/tier_config.py +354 -0
  87. invarlock/guards/variance.py +3298 -0
  88. invarlock/guards_ref/__init__.py +15 -0
  89. invarlock/guards_ref/rmt_ref.py +40 -0
  90. invarlock/guards_ref/spectral_ref.py +135 -0
  91. invarlock/guards_ref/variance_ref.py +60 -0
  92. invarlock/model_profile.py +353 -0
  93. invarlock/model_utils.py +221 -0
  94. invarlock/observability/__init__.py +10 -0
  95. invarlock/observability/alerting.py +535 -0
  96. invarlock/observability/core.py +546 -0
  97. invarlock/observability/exporters.py +565 -0
  98. invarlock/observability/health.py +588 -0
  99. invarlock/observability/metrics.py +457 -0
  100. invarlock/observability/py.typed +1 -0
  101. invarlock/observability/utils.py +553 -0
  102. invarlock/plugins/__init__.py +12 -0
  103. invarlock/plugins/hello_guard.py +33 -0
  104. invarlock/plugins/hf_awq_adapter.py +82 -0
  105. invarlock/plugins/hf_bnb_adapter.py +79 -0
  106. invarlock/plugins/hf_gptq_adapter.py +78 -0
  107. invarlock/plugins/py.typed +1 -0
  108. invarlock/py.typed +1 -0
  109. invarlock/reporting/__init__.py +7 -0
  110. invarlock/reporting/certificate.py +3221 -0
  111. invarlock/reporting/certificate_schema.py +244 -0
  112. invarlock/reporting/dataset_hashing.py +215 -0
  113. invarlock/reporting/guards_analysis.py +948 -0
  114. invarlock/reporting/html.py +32 -0
  115. invarlock/reporting/normalizer.py +235 -0
  116. invarlock/reporting/policy_utils.py +517 -0
  117. invarlock/reporting/primary_metric_utils.py +265 -0
  118. invarlock/reporting/render.py +1442 -0
  119. invarlock/reporting/report.py +903 -0
  120. invarlock/reporting/report_types.py +278 -0
  121. invarlock/reporting/utils.py +175 -0
  122. invarlock/reporting/validate.py +631 -0
  123. invarlock/security.py +176 -0
  124. invarlock/sparsity_utils.py +323 -0
  125. invarlock/utils/__init__.py +150 -0
  126. invarlock/utils/digest.py +45 -0
  127. invarlock-0.2.0.dist-info/METADATA +586 -0
  128. invarlock-0.2.0.dist-info/RECORD +132 -0
  129. invarlock-0.2.0.dist-info/WHEEL +5 -0
  130. invarlock-0.2.0.dist-info/entry_points.txt +20 -0
  131. invarlock-0.2.0.dist-info/licenses/LICENSE +201 -0
  132. invarlock-0.2.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,154 @@
1
+ """
2
+ InvarLock Core Types
3
+ ================
4
+
5
+ Core type definitions and enums used throughout InvarLock.
6
+ Torch-independent type system for cross-module compatibility.
7
+ """
8
+
9
+ from dataclasses import dataclass
10
+ from enum import Enum
11
+ from typing import Any, NamedTuple
12
+
13
+
14
class EditType(Enum):
    """Enumerates the categories of model edit InvarLock supports."""

    # String values are stable identifiers used wherever edits are serialized.
    QUANTIZATION = "quantization"
    SPARSITY = "sparsity"
    MIXED = "mixed"
20
+
21
+
22
class GuardType(Enum):
    """Enumerates the kinds of safety guard that can be run."""

    # String values are stable identifiers used wherever guards are serialized.
    INVARIANTS = "invariants"
    SPECTRAL = "spectral"
    VARIANCE = "variance"
    RMT = "rmt"
    NOOP = "noop"
30
+
31
+
32
class RunStatus(Enum):
    """Lifecycle states a pipeline run can be in."""

    PENDING = "pending"      # created but not yet started
    RUNNING = "running"      # currently executing
    SUCCESS = "success"      # finished without failure
    FAILED = "failed"        # finished with an error
    ROLLBACK = "rollback"    # reverted after a guard/violation
    CANCELLED = "cancelled"  # stopped before completion
41
+
42
+
43
class LogLevel(Enum):
    """Severity levels for event logging (values mirror stdlib logging names)."""

    DEBUG = "DEBUG"
    INFO = "INFO"
    WARNING = "WARNING"
    ERROR = "ERROR"
    CRITICAL = "CRITICAL"
51
+
52
+
53
@dataclass
class ModelInfo:
    """Summary of a loaded model's identity and placement."""

    model_id: str      # model identifier (presumably a hub id or path — TODO confirm)
    architecture: str  # architecture family name
    parameters: int    # total parameter count
    device: str        # device the model resides on
    precision: str = "float32"  # numeric precision label
62
+
63
+
64
@dataclass
class EditInfo:
    """Describes a single edit that was applied to a model."""

    name: str                   # registry name of the edit
    type: EditType              # edit category (field intentionally shadows builtin `type`)
    parameters: dict[str, Any]  # parameters the edit ran with
    compression_ratio: float | None = None          # final/original params when known
    target_metrics: dict[str, float] | None = None  # optional metric targets
73
+
74
+
75
+ @dataclass
76
+ class GuardResult:
77
+ """Result from a guard validation."""
78
+
79
+ guard_name: str
80
+ passed: bool
81
+ score: float | None = None
82
+ threshold: float | None = None
83
+ message: str | None = None
84
+ details: dict[str, Any] | None = None
85
+
86
+
87
class ValidationResult(NamedTuple):
    """Immutable (passed, score, threshold, message) validation tuple."""

    passed: bool       # overall verdict
    score: float       # measured value
    threshold: float   # comparison limit
    message: str = ""  # optional explanation
94
+
95
+
96
+ @dataclass
97
+ class GuardOutcome:
98
+ """Result from a guard execution."""
99
+
100
+ name: str
101
+ passed: bool
102
+ action: str = "none"
103
+ violations: list[dict[str, Any]] | None = None
104
+ metrics: dict[str, Any] | None = None
105
+
106
+ def __post_init__(self):
107
+ if self.violations is None:
108
+ self.violations = []
109
+ if self.metrics is None:
110
+ self.metrics = {}
111
+
112
+
113
+ @dataclass
114
+ class PolicyConfig:
115
+ """Configuration for guard policies."""
116
+
117
+ on_violation: str = "warn"
118
+ guard_overrides: dict[str, str] | None = None
119
+ enable_auto_rollback: bool = False
120
+
121
+ def __post_init__(self):
122
+ if self.guard_overrides is None:
123
+ self.guard_overrides = {}
124
+
125
+ def get_action_for_guard(self, guard_name: str, requested_action: str) -> str:
126
+ """Get the action for a specific guard."""
127
+ # Check if there's an override for this guard
128
+ if self.guard_overrides and guard_name in self.guard_overrides:
129
+ return self.guard_overrides[guard_name]
130
+
131
+ # If requested action is not 'none', use it
132
+ if requested_action != "none":
133
+ return requested_action
134
+
135
+ # Fall back to global default
136
+ return self.on_violation
137
+
138
+
139
def get_worst_action(actions: list[str]) -> str:
    """Return the most severe action in *actions* ("none" for an empty list).

    Severity order: none < warn < rollback < abort; unrecognised actions rank
    lowest. Ties keep the earliest entry (same as ``max`` semantics).
    """
    severity = {"none": 0, "warn": 1, "rollback": 2, "abort": 3}

    worst = "none"
    worst_rank = -1
    for candidate in actions:
        rank = severity.get(candidate, 0)
        if rank > worst_rank:
            worst, worst_rank = candidate, rank
    return worst
147
+
148
+
149
# Type aliases for clarity (documentation only — not enforced at runtime)
DeviceSpec = str | Any  # Device specification (`Any` keeps this torch-independent, e.g. for torch.device)
ConfigDict = dict[str, Any]  # Configuration dictionary
MetricsDict = dict[str, float | int | str | bool]  # Metrics keyed by metric name
LayerIndex = int  # Layer index
HeadIndex = int  # Attention head index
@@ -0,0 +1,12 @@
1
"""Edit namespace (`invarlock.edits`) re-exporting built-in edits."""

from __future__ import annotations

# The `X as X` form marks this as an intentional re-export for type checkers.
from invarlock.core.abi import INVARLOCK_CORE_ABI as INVARLOCK_CORE_ABI

from .quant_rtn import RTNQuantEdit

# Explicit public API of the package namespace.
__all__ = [
    "RTNQuantEdit",
    "INVARLOCK_CORE_ABI",
]
@@ -0,0 +1,249 @@
1
+ """
2
+ InvarLock Edit Utilities
3
+ ====================
4
+
5
+ Shared helper functions for edit implementations.
6
+ Common functionality used across multiple edit types.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import warnings
12
+ from typing import Any
13
+
14
# Torch is treated as optional here: helpers that need it check
# TORCH_AVAILABLE and degrade gracefully (return neutral values) without it.
try:
    import torch.nn as nn

    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False
20
+
21
# Public helpers of this module. All helpers defined below are exported so
# star-imports match the module contents (get_module_parameter_count,
# validate_compression_target and create_layer_mask were previously defined
# but omitted, inconsistent with the sibling _external_utils module).
__all__ = [
    "validate_model_structure",
    "calculate_compression_ratio",
    "get_layer_info",
    "validate_edit_parameters",
    "create_edit_metadata",
    "get_module_parameter_count",
    "validate_compression_target",
    "create_layer_mask",
]
28
+
29
+
30
def validate_model_structure(
    model_desc: dict[str, Any], required_fields: list[str]
) -> bool:
    """Check that *model_desc* carries every field an edit requires.

    Args:
        model_desc: Model description produced by an adapter.
        required_fields: Field names the edit cannot run without.

    Returns:
        True when every required field is present; otherwise emits a
        ``UserWarning`` naming the missing fields and returns False.
    """
    missing_fields = [field for field in required_fields if field not in model_desc]

    if not missing_fields:
        return True

    warnings.warn(
        f"Model missing required fields for edit: {missing_fields}", stacklevel=2
    )
    return False
56
+
57
+
58
def calculate_compression_ratio(original_params: int, final_params: int) -> float:
    """Return the parameter compression ratio (final / original).

    Args:
        original_params: Parameter count before the edit.
        final_params: Parameter count after the edit.

    Returns:
        ``final_params / original_params``; 1.0 when the original count is
        zero (avoids division by zero — an empty model is "unchanged").
    """
    return 1.0 if original_params == 0 else final_params / original_params
73
+
74
+
75
def get_layer_info(model_desc: dict[str, Any], layer_idx: int) -> dict[str, Any]:
    """Extract per-layer information from a model description.

    Args:
        model_desc: Model description produced by an adapter.
        layer_idx: Zero-based index of the layer.

    Returns:
        Dict with ``layer_idx``, ``n_heads``, ``mlp_dim`` and ``hidden_size``;
        values fall back to None where the description lacks data.

    Raises:
        IndexError: If ``layer_idx`` is out of range for the model.
    """
    n_layers = model_desc.get("n_layer", 0)
    if layer_idx >= n_layers:
        raise IndexError(f"Layer index {layer_idx} >= {n_layers}")

    def _entry(values: list[Any]) -> Any:
        # Per-layer lists may be shorter than n_layer; treat gaps as unknown.
        return values[layer_idx] if layer_idx < len(values) else None

    return {
        "layer_idx": layer_idx,
        "n_heads": _entry(model_desc.get("heads_per_layer", [])),
        "mlp_dim": _entry(model_desc.get("mlp_dims", [])),
        "hidden_size": model_desc.get("hidden_size"),
    }
104
+
105
+
106
+ def validate_edit_parameters(
107
+ params: dict[str, Any],
108
+ required_params: list[str],
109
+ optional_params: dict[str, Any] | None = None,
110
+ ) -> tuple[bool, str]:
111
+ """
112
+ Validate edit parameters.
113
+
114
+ Args:
115
+ params: Parameters to validate
116
+ required_params: List of required parameter names
117
+ optional_params: Dict of optional parameters with default values
118
+
119
+ Returns:
120
+ (is_valid, error_message)
121
+ """
122
+ # Check required parameters
123
+ missing_required = []
124
+ for param in required_params:
125
+ if param not in params:
126
+ missing_required.append(param)
127
+
128
+ if missing_required:
129
+ return False, f"Missing required parameters: {missing_required}"
130
+
131
+ # Add default values for optional parameters
132
+ if optional_params:
133
+ for param, default_value in optional_params.items():
134
+ if param not in params:
135
+ params[param] = default_value
136
+
137
+ return True, "Parameters valid"
138
+
139
+
140
def create_edit_metadata(
    edit_name: str,
    model_desc: dict[str, Any],
    parameters: dict[str, Any],
    result: dict[str, Any],
) -> dict[str, Any]:
    """Build the standardized metadata dict for a completed edit.

    Args:
        edit_name: Name of the edit operation.
        model_desc: Model description prior to the edit.
        parameters: Parameters the edit ran with (copied, not aliased).
        result: Raw results from the edit; edit-specific keys are merged
            into the ``results`` section without overwriting standard ones.

    Returns:
        Metadata dict with ``name``, ``parameters``, ``original_model`` and
        ``results`` sections.
    """
    params_before = model_desc.get("total_params", 0)
    # If the edit did not report a final count, assume the count is unchanged.
    params_after = result.get("final_params", params_before)

    results: dict[str, Any] = {
        "final_params": params_after,
        # Same ratio semantics as calculate_compression_ratio (inlined).
        "compression_ratio": (
            1.0 if params_before == 0 else params_after / params_before
        ),
        "layers_modified": result.get("layers_modified", []),
        "success": result.get("success", True),
    }
    # Carry through any edit-specific keys the standard section lacks.
    for key, value in result.items():
        results.setdefault(key, value)

    return {
        "name": edit_name,
        "parameters": parameters.copy(),
        "original_model": {
            "n_layer": model_desc.get("n_layer", 0),
            "total_params": params_before,
            "model_type": model_desc.get("model_type", "unknown"),
        },
        "results": results,
    }
187
+
188
+
189
def get_module_parameter_count(module) -> int:
    """Count the parameters of a torch ``nn.Module``.

    Args:
        module: Candidate module. Any non-``nn.Module`` input — or any input
            at all when torch is unavailable — counts as 0.

    Returns:
        Total number of parameter elements, or 0 when uncountable.
    """
    # Short-circuit keeps `nn` from being touched when torch failed to import.
    if not TORCH_AVAILABLE or not isinstance(module, nn.Module):
        return 0
    return sum(param.numel() for param in module.parameters())
206
+
207
+
208
+ def validate_compression_target(
209
+ target_ratio: Any, min_ratio: float = 0.1, max_ratio: float = 1.0
210
+ ) -> bool:
211
+ """
212
+ Validate compression target ratio.
213
+
214
+ Args:
215
+ target_ratio: Desired compression ratio
216
+ min_ratio: Minimum allowed ratio
217
+ max_ratio: Maximum allowed ratio
218
+
219
+ Returns:
220
+ True if ratio is valid
221
+ """
222
+ if not isinstance(target_ratio, int | float):
223
+ return False
224
+
225
+ return min_ratio <= target_ratio <= max_ratio
226
+
227
+
228
+ def create_layer_mask(
229
+ n_layers: int, layers_to_edit: list[int] | None = None
230
+ ) -> list[bool]:
231
+ """
232
+ Create a boolean mask for layers to edit.
233
+
234
+ Args:
235
+ n_layers: Total number of layers
236
+ layers_to_edit: List of layer indices to edit (None = all layers)
237
+
238
+ Returns:
239
+ Boolean mask where True indicates layer should be edited
240
+ """
241
+ if layers_to_edit is None:
242
+ return [True] * n_layers
243
+
244
+ mask = [False] * n_layers
245
+ for layer_idx in layers_to_edit:
246
+ if 0 <= layer_idx < n_layers:
247
+ mask[layer_idx] = True
248
+
249
+ return mask
@@ -0,0 +1,268 @@
1
+ """
2
+ External Utilities for Edit Operations
3
+ =====================================
4
+
5
+ Utilities for integrating with external edit backends and guard chains.
6
+ Provides common functionality for model snapshots, validation, and guard policies.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import warnings
12
+ from typing import Any
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+
17
+
18
class ExternalBackend:
    """Abstract interface for external edit backends.

    Concrete backends override both methods; the base implementations
    always raise ``NotImplementedError``.
    """

    def apply_edit(self, model: nn.Module, config: dict, calib: Any) -> dict:
        """Apply the edit to *model* and return backend-specific results."""
        raise NotImplementedError

    def get_edit_info(self, model: nn.Module, config: dict) -> dict:
        """Describe the edit that would be applied, without applying it."""
        raise NotImplementedError
28
+
29
+
30
class ExternalEditWrapper:
    """Adapts an :class:`ExternalBackend` to the edit preview/apply flow.

    NOTE(review): both methods currently return empty placeholder payloads
    and never consult ``self.backend`` — looks like a stub; confirm intent.
    """

    def __init__(self, backend: ExternalBackend):
        # Stored for future use; not consulted by the stub methods below.
        self.backend = backend

    def preview(self, model: nn.Module, adapter: Any, calib: Any) -> dict:
        """Return an (empty) preview payload for the edit."""
        empty_preview = {"plan": {}, "estimated_sparsity": {}, "preview_metrics": {}}
        return empty_preview

    def apply(self, model: nn.Module, adapter: Any, plan: dict) -> dict:
        """Return an (empty) application payload for the edit."""
        empty_result = {"actual_sparsity": {}, "modified_layers": [], "metrics": {}}
        return empty_result
43
+
44
+
45
def check_dependencies(edit_type: str) -> bool:
    """Report whether the dependencies for an edit type are available.

    Args:
        edit_type: Type of edit (e.g. 'quant').

    Returns:
        True when the edit type is usable in this environment. Quantization
        relies only on standard PyTorch; 'svd' and all unknown types are
        reported as unavailable.
    """
    availability = {
        "svd": False,
        "quant": True,  # uses standard PyTorch only
    }
    return availability.get(edit_type, False)
62
+
63
+
64
def validate_edit_config(config: dict, edit_type: str) -> dict:
    """Validate and normalize an edit configuration.

    Args:
        config: Raw configuration dictionary (not mutated; a copy is returned).
        edit_type: Type of edit operation.

    Returns:
        A validated copy of the configuration.

    Raises:
        ValueError: If ``edit_type`` is 'svd', which is unsupported.
    """
    if edit_type == "svd":
        raise ValueError("Unsupported edit type: svd")

    validated = config.copy()

    if edit_type == "quant":
        # Warn (but do not fail) on uncommon bit widths.
        bits = validated.get("bits", 8)
        if bits not in (4, 8, 16):
            warnings.warn(
                f"Unusual bit width {bits}, common values are 4, 8, or 16", stacklevel=2
            )

    return validated
89
+
90
+
91
def compute_edit_metrics(
    model_before: nn.Module, model_after: nn.Module, config: dict
) -> dict:
    """Compare two model states and report size/sparsity metrics.

    Args:
        model_before: Model state before the edit.
        model_after: Model state after the edit.
        config: Edit configuration (not used by the computation itself).

    Returns:
        Dict with parameter counts, non-zero counts, compression ratio,
        achieved sparsity, and estimated memory saved.
        NOTE(review): ``memory_saved_mb`` assumes 4 bytes per parameter
        (float32) — confirm for other dtypes.
    """

    def _counts(model: nn.Module) -> tuple[int, int]:
        # (total elements, non-zero elements) across all parameters.
        total = 0
        nonzero = 0
        for tensor in model.parameters():
            total += tensor.numel()
            nonzero += (tensor != 0).sum().item()
        return total, nonzero

    params_before, nonzero_before = _counts(model_before)
    params_after, nonzero_after = _counts(model_after)

    param_ratio = params_after / params_before if params_before > 0 else 1.0
    sparsity_ratio = (
        1.0 - (nonzero_after / nonzero_before) if nonzero_before > 0 else 0.0
    )

    return {
        "params_before": params_before,
        "params_after": params_after,
        "nonzero_before": nonzero_before,
        "nonzero_after": nonzero_after,
        "param_compression_ratio": param_ratio,
        "sparsity_achieved": sparsity_ratio,
        "memory_saved_mb": (params_before - params_after) * 4 / (1024 * 1024),
    }
128
+
129
+
130
def prepare_calibration_data(calib_data: Any, config: dict) -> Any:
    """Prepare calibration data for edit operations.

    Args:
        calib_data: Raw calibration data (None, a bare tensor, a list, or
            any sized iterable such as a DataLoader).
        config: Edit configuration (currently unused here).

    Returns:
        None for None input; ``[tensor]`` for a bare tensor; otherwise the
        input unchanged.
    """
    if calib_data is None:
        return None

    # Check tensors FIRST: torch.Tensor also exposes __iter__/__len__, so the
    # generic sized-iterable test below would otherwise return a bare tensor
    # un-wrapped and this branch could never trigger.
    if isinstance(calib_data, torch.Tensor):
        return [calib_data]

    # Lists and other sized iterables are already in a usable form.
    if isinstance(calib_data, list):
        return calib_data
    if hasattr(calib_data, "__iter__") and hasattr(calib_data, "__len__"):
        return calib_data

    return calib_data
157
+
158
+
159
def safe_model_snapshot(model: nn.Module, adapter: Any) -> dict:
    """Capture model state for a potential rollback; never raises.

    Args:
        model: Model to snapshot.
        adapter: Adapter that may expose ``get_config(model)``.

    Returns:
        Dict with ``state_dict`` (clones of trainable parameters only —
        frozen parameters and buffers are NOT captured), ``config`` and
        ``snapshot_time``; an empty dict if anything goes wrong.
        NOTE(review): ``snapshot_time`` is a fresh, never-recorded CUDA
        event (or None without CUDA) — it carries no usable timestamp.
    """
    try:
        # Clone so the snapshot does not alias live parameter storage.
        captured = {
            name: param.data.clone()
            for name, param in model.named_parameters()
            if param.requires_grad
        }

        # Best-effort: some adapters know how to extract the model config.
        model_config: dict = {}
        if hasattr(adapter, "get_config"):
            try:
                model_config = adapter.get_config(model)
            except Exception:
                pass

        return {
            "state_dict": captured,
            "config": model_config,
            "snapshot_time": (
                torch.cuda.Event(enable_timing=True)
                if torch.cuda.is_available()
                else None
            ),
        }
    except Exception as e:
        warnings.warn(f"Failed to create model snapshot: {e}", stacklevel=2)
        return {}
195
+
196
+
197
def safe_model_restore(model: nn.Module, snapshot: dict, adapter: Any) -> bool:
    """Restore parameters from a snapshot made by ``safe_model_snapshot``.

    Args:
        model: Model to restore into.
        snapshot: Snapshot dict; must contain ``state_dict``.
        adapter: Model adapter (not used by the restore path).

    Returns:
        True on success; False (with a warning) when the snapshot is
        unusable or a copy fails. Parameters absent from the snapshot are
        left untouched.
    """
    try:
        if "state_dict" not in snapshot:
            warnings.warn("Snapshot missing state_dict, cannot restore", stacklevel=2)
            return False

        saved = snapshot["state_dict"]
        # In-place copy preserves parameter identity (optimizers keep working).
        for name, param in model.named_parameters():
            if name in saved:
                param.data.copy_(saved[name])

        return True
    except Exception as e:
        warnings.warn(f"Failed to restore model from snapshot: {e}", stacklevel=2)
        return False
224
+
225
+
226
def create_baseline_guard_policy(edit_type: str) -> dict:
    """Return the baseline guard-policy settings for an edit type.

    Args:
        edit_type: Type of edit operation.

    Returns:
        A fresh guard-policy dict. Quantization gets a looser policy
        (spectral monitoring off, higher change tolerances); every other
        edit type gets the conservative defaults.
    """
    if edit_type == "quant":
        return {
            "spectral_monitoring": False,  # Quantization may change spectral properties
            "rmt_detection": True,
            "weight_change_threshold": 0.5,  # Higher tolerance for quantization
            "activation_change_threshold": 0.3,
            "max_spectral_norm_increase": 3.0,
        }

    return {
        "spectral_monitoring": True,
        "rmt_detection": True,
        "weight_change_threshold": 0.1,
        "activation_change_threshold": 0.1,
        "max_spectral_norm_increase": 2.0,
    }
256
+
257
+
258
# Explicit public API of this module (every helper defined above).
__all__ = [
    "ExternalBackend",
    "ExternalEditWrapper",
    "check_dependencies",
    "validate_edit_config",
    "compute_edit_metrics",
    "prepare_calibration_data",
    "safe_model_snapshot",
    "safe_model_restore",
    "create_baseline_guard_policy",
]