invarlock 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invarlock/__init__.py +33 -0
- invarlock/__main__.py +10 -0
- invarlock/_data/runtime/profiles/ci_cpu.yaml +15 -0
- invarlock/_data/runtime/profiles/release.yaml +23 -0
- invarlock/_data/runtime/tiers.yaml +76 -0
- invarlock/adapters/__init__.py +102 -0
- invarlock/adapters/_capabilities.py +45 -0
- invarlock/adapters/auto.py +99 -0
- invarlock/adapters/base.py +530 -0
- invarlock/adapters/base_types.py +85 -0
- invarlock/adapters/hf_bert.py +852 -0
- invarlock/adapters/hf_gpt2.py +403 -0
- invarlock/adapters/hf_llama.py +485 -0
- invarlock/adapters/hf_mixin.py +383 -0
- invarlock/adapters/hf_onnx.py +112 -0
- invarlock/adapters/hf_t5.py +137 -0
- invarlock/adapters/py.typed +1 -0
- invarlock/assurance/__init__.py +43 -0
- invarlock/cli/__init__.py +8 -0
- invarlock/cli/__main__.py +8 -0
- invarlock/cli/_evidence.py +25 -0
- invarlock/cli/_json.py +75 -0
- invarlock/cli/adapter_auto.py +162 -0
- invarlock/cli/app.py +287 -0
- invarlock/cli/commands/__init__.py +26 -0
- invarlock/cli/commands/certify.py +403 -0
- invarlock/cli/commands/doctor.py +1358 -0
- invarlock/cli/commands/explain_gates.py +151 -0
- invarlock/cli/commands/export_html.py +100 -0
- invarlock/cli/commands/plugins.py +1331 -0
- invarlock/cli/commands/report.py +354 -0
- invarlock/cli/commands/run.py +4146 -0
- invarlock/cli/commands/verify.py +1040 -0
- invarlock/cli/config.py +396 -0
- invarlock/cli/constants.py +68 -0
- invarlock/cli/device.py +92 -0
- invarlock/cli/doctor_helpers.py +74 -0
- invarlock/cli/errors.py +6 -0
- invarlock/cli/overhead_utils.py +60 -0
- invarlock/cli/provenance.py +66 -0
- invarlock/cli/utils.py +41 -0
- invarlock/config.py +56 -0
- invarlock/core/__init__.py +62 -0
- invarlock/core/abi.py +15 -0
- invarlock/core/api.py +274 -0
- invarlock/core/auto_tuning.py +317 -0
- invarlock/core/bootstrap.py +226 -0
- invarlock/core/checkpoint.py +221 -0
- invarlock/core/contracts.py +73 -0
- invarlock/core/error_utils.py +64 -0
- invarlock/core/events.py +298 -0
- invarlock/core/exceptions.py +95 -0
- invarlock/core/registry.py +481 -0
- invarlock/core/retry.py +146 -0
- invarlock/core/runner.py +2041 -0
- invarlock/core/types.py +154 -0
- invarlock/edits/__init__.py +12 -0
- invarlock/edits/_edit_utils.py +249 -0
- invarlock/edits/_external_utils.py +268 -0
- invarlock/edits/noop.py +47 -0
- invarlock/edits/py.typed +1 -0
- invarlock/edits/quant_rtn.py +801 -0
- invarlock/edits/registry.py +166 -0
- invarlock/eval/__init__.py +23 -0
- invarlock/eval/bench.py +1207 -0
- invarlock/eval/bootstrap.py +50 -0
- invarlock/eval/data.py +2052 -0
- invarlock/eval/metrics.py +2167 -0
- invarlock/eval/primary_metric.py +767 -0
- invarlock/eval/probes/__init__.py +24 -0
- invarlock/eval/probes/fft.py +139 -0
- invarlock/eval/probes/mi.py +213 -0
- invarlock/eval/probes/post_attention.py +323 -0
- invarlock/eval/providers/base.py +67 -0
- invarlock/eval/providers/seq2seq.py +111 -0
- invarlock/eval/providers/text_lm.py +113 -0
- invarlock/eval/providers/vision_text.py +93 -0
- invarlock/eval/py.typed +1 -0
- invarlock/guards/__init__.py +18 -0
- invarlock/guards/_contracts.py +9 -0
- invarlock/guards/invariants.py +640 -0
- invarlock/guards/policies.py +805 -0
- invarlock/guards/py.typed +1 -0
- invarlock/guards/rmt.py +2097 -0
- invarlock/guards/spectral.py +1419 -0
- invarlock/guards/tier_config.py +354 -0
- invarlock/guards/variance.py +3298 -0
- invarlock/guards_ref/__init__.py +15 -0
- invarlock/guards_ref/rmt_ref.py +40 -0
- invarlock/guards_ref/spectral_ref.py +135 -0
- invarlock/guards_ref/variance_ref.py +60 -0
- invarlock/model_profile.py +353 -0
- invarlock/model_utils.py +221 -0
- invarlock/observability/__init__.py +10 -0
- invarlock/observability/alerting.py +535 -0
- invarlock/observability/core.py +546 -0
- invarlock/observability/exporters.py +565 -0
- invarlock/observability/health.py +588 -0
- invarlock/observability/metrics.py +457 -0
- invarlock/observability/py.typed +1 -0
- invarlock/observability/utils.py +553 -0
- invarlock/plugins/__init__.py +12 -0
- invarlock/plugins/hello_guard.py +33 -0
- invarlock/plugins/hf_awq_adapter.py +82 -0
- invarlock/plugins/hf_bnb_adapter.py +79 -0
- invarlock/plugins/hf_gptq_adapter.py +78 -0
- invarlock/plugins/py.typed +1 -0
- invarlock/py.typed +1 -0
- invarlock/reporting/__init__.py +7 -0
- invarlock/reporting/certificate.py +3221 -0
- invarlock/reporting/certificate_schema.py +244 -0
- invarlock/reporting/dataset_hashing.py +215 -0
- invarlock/reporting/guards_analysis.py +948 -0
- invarlock/reporting/html.py +32 -0
- invarlock/reporting/normalizer.py +235 -0
- invarlock/reporting/policy_utils.py +517 -0
- invarlock/reporting/primary_metric_utils.py +265 -0
- invarlock/reporting/render.py +1442 -0
- invarlock/reporting/report.py +903 -0
- invarlock/reporting/report_types.py +278 -0
- invarlock/reporting/utils.py +175 -0
- invarlock/reporting/validate.py +631 -0
- invarlock/security.py +176 -0
- invarlock/sparsity_utils.py +323 -0
- invarlock/utils/__init__.py +150 -0
- invarlock/utils/digest.py +45 -0
- invarlock-0.2.0.dist-info/METADATA +586 -0
- invarlock-0.2.0.dist-info/RECORD +132 -0
- invarlock-0.2.0.dist-info/WHEEL +5 -0
- invarlock-0.2.0.dist-info/entry_points.txt +20 -0
- invarlock-0.2.0.dist-info/licenses/LICENSE +201 -0
- invarlock-0.2.0.dist-info/top_level.txt +1 -0
invarlock/core/types.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
"""
|
|
2
|
+
InvarLock Core Types
|
|
3
|
+
================
|
|
4
|
+
|
|
5
|
+
Core type definitions and enums used throughout InvarLock.
|
|
6
|
+
Torch-independent type system for cross-module compatibility.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from dataclasses import dataclass
|
|
10
|
+
from enum import Enum
|
|
11
|
+
from typing import Any, NamedTuple
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class EditType(Enum):
    """Types of model edits supported by InvarLock.

    Values are the lowercase identifier strings used in configs/reports.
    """

    QUANTIZATION = "quantization"  # weight quantization
    SPARSITY = "sparsity"  # weight pruning / sparsification
    MIXED = "mixed"  # combination of multiple edit kinds
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class GuardType(Enum):
    """Types of safety guards available.

    String values name the guard implementations (presumably used as
    registry keys — confirm against the guard registry).
    """

    INVARIANTS = "invariants"
    SPECTRAL = "spectral"
    VARIANCE = "variance"
    RMT = "rmt"
    NOOP = "noop"  # do-nothing guard
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class RunStatus(Enum):
    """Execution status for pipeline runs."""

    PENDING = "pending"  # queued, not yet started
    RUNNING = "running"  # currently executing
    SUCCESS = "success"  # finished successfully
    FAILED = "failed"  # finished with an error
    ROLLBACK = "rollback"  # run ended in a rollback
    CANCELLED = "cancelled"  # stopped before completion
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class LogLevel(Enum):
    """Logging levels for events.

    Mirrors the standard :mod:`logging` level names as uppercase strings.
    """

    DEBUG = "DEBUG"
    INFO = "INFO"
    WARNING = "WARNING"
    ERROR = "ERROR"
    CRITICAL = "CRITICAL"
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@dataclass
class ModelInfo:
    """Basic model information."""

    model_id: str  # model identifier (presumably a hub id/path — confirm)
    architecture: str  # architecture name
    parameters: int  # total parameter count
    device: str  # device the model lives on (e.g. "cpu")
    precision: str = "float32"  # numeric precision label
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@dataclass
class EditInfo:
    """Information about an applied edit."""

    name: str  # edit name
    type: EditType  # category of the edit
    parameters: dict[str, Any]  # parameters the edit was applied with
    compression_ratio: float | None = None  # final/original params, if known
    target_metrics: dict[str, float] | None = None  # optional quality targets
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
@dataclass
class GuardResult:
    """Result from a guard validation."""

    guard_name: str  # guard that produced this result
    passed: bool  # whether the check passed
    score: float | None = None  # measured value, if the guard produces one
    threshold: float | None = None  # threshold the score was compared against
    message: str | None = None  # human-readable summary
    details: dict[str, Any] | None = None  # guard-specific extras
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class ValidationResult(NamedTuple):
    """Result from validation operations.

    An immutable (passed, score, threshold, message) tuple; unlike
    :class:`GuardResult`, score and threshold are required here.
    """

    passed: bool
    score: float
    threshold: float
    message: str = ""
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
@dataclass
class GuardOutcome:
    """Result from a guard execution.

    ``violations`` and ``metrics`` accept ``None`` and are normalised to an
    empty list/dict in ``__post_init__`` so callers can always iterate them.
    """

    name: str  # guard identifier
    passed: bool  # whether the guard check succeeded
    action: str = "none"  # action taken/recommended (e.g. "warn")
    violations: list[dict[str, Any]] | None = None
    metrics: dict[str, Any] | None = None

    def __post_init__(self) -> None:
        # Mutable defaults are not allowed on dataclass fields, so None is
        # used as the sentinel and replaced with fresh containers here.
        if self.violations is None:
            self.violations = []
        if self.metrics is None:
            self.metrics = {}
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
@dataclass
|
|
114
|
+
class PolicyConfig:
|
|
115
|
+
"""Configuration for guard policies."""
|
|
116
|
+
|
|
117
|
+
on_violation: str = "warn"
|
|
118
|
+
guard_overrides: dict[str, str] | None = None
|
|
119
|
+
enable_auto_rollback: bool = False
|
|
120
|
+
|
|
121
|
+
def __post_init__(self):
|
|
122
|
+
if self.guard_overrides is None:
|
|
123
|
+
self.guard_overrides = {}
|
|
124
|
+
|
|
125
|
+
def get_action_for_guard(self, guard_name: str, requested_action: str) -> str:
|
|
126
|
+
"""Get the action for a specific guard."""
|
|
127
|
+
# Check if there's an override for this guard
|
|
128
|
+
if self.guard_overrides and guard_name in self.guard_overrides:
|
|
129
|
+
return self.guard_overrides[guard_name]
|
|
130
|
+
|
|
131
|
+
# If requested action is not 'none', use it
|
|
132
|
+
if requested_action != "none":
|
|
133
|
+
return requested_action
|
|
134
|
+
|
|
135
|
+
# Fall back to global default
|
|
136
|
+
return self.on_violation
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def get_worst_action(actions: list[str]) -> str:
    """Return the most severe action in *actions*.

    Severity order: none < warn < rollback < abort; unknown strings rank
    with "none". Ties resolve to the earliest occurrence, and an empty
    list yields "none".
    """
    severity = {"none": 0, "warn": 1, "rollback": 2, "abort": 3}

    worst = "none"
    worst_rank = -1
    for action in actions:
        rank = severity.get(action, 0)
        # Strictly-greater keeps the first occurrence among equals,
        # matching max()'s tie-breaking.
        if rank > worst_rank:
            worst, worst_rank = action, rank
    return worst
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
# Type aliases for clarity
DeviceSpec = str | Any  # Device specification (string name or device object)
ConfigDict = dict[str, Any]  # Configuration dictionary
MetricsDict = dict[str, float | int | str | bool]  # Metrics
LayerIndex = int  # Layer index
HeadIndex = int  # Attention head index
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
"""Edit namespace (`invarlock.edits`) re-exporting built-in edits."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from invarlock.core.abi import INVARLOCK_CORE_ABI as INVARLOCK_CORE_ABI
|
|
6
|
+
|
|
7
|
+
from .quant_rtn import RTNQuantEdit
|
|
8
|
+
|
|
9
|
+
__all__ = [
|
|
10
|
+
"RTNQuantEdit",
|
|
11
|
+
"INVARLOCK_CORE_ABI",
|
|
12
|
+
]
|
|
@@ -0,0 +1,249 @@
|
|
|
1
|
+
"""
|
|
2
|
+
InvarLock Edit Utilities
|
|
3
|
+
====================
|
|
4
|
+
|
|
5
|
+
Shared helper functions for edit implementations.
|
|
6
|
+
Common functionality used across multiple edit types.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import warnings
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
try:
|
|
15
|
+
import torch.nn as nn
|
|
16
|
+
|
|
17
|
+
TORCH_AVAILABLE = True
|
|
18
|
+
except ImportError:
|
|
19
|
+
TORCH_AVAILABLE = False
|
|
20
|
+
|
|
21
|
+
__all__ = [
|
|
22
|
+
"validate_model_structure",
|
|
23
|
+
"calculate_compression_ratio",
|
|
24
|
+
"get_layer_info",
|
|
25
|
+
"validate_edit_parameters",
|
|
26
|
+
"create_edit_metadata",
|
|
27
|
+
]
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def validate_model_structure(
    model_desc: dict[str, Any], required_fields: list[str]
) -> bool:
    """
    Validate that a model description has required fields for an edit.

    Args:
        model_desc: Model description from adapter
        required_fields: List of required field names

    Returns:
        True if all required fields are present; otherwise emits a
        ``UserWarning`` naming the missing fields and returns False.
    """
    # Comprehension instead of a manual append loop (idiom).
    missing_fields = [field for field in required_fields if field not in model_desc]

    if missing_fields:
        warnings.warn(
            f"Model missing required fields for edit: {missing_fields}", stacklevel=2
        )
        return False

    return True
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def calculate_compression_ratio(original_params: int, final_params: int) -> float:
    """Return the parameter compression ratio ``final / original``.

    Args:
        original_params: Original parameter count
        final_params: Final parameter count after edit

    Returns:
        ``final_params / original_params``; a zero original count yields
        1.0 (treated as "unchanged") instead of dividing by zero.
    """
    return 1.0 if original_params == 0 else final_params / original_params
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def get_layer_info(model_desc: dict[str, Any], layer_idx: int) -> dict[str, Any]:
    """
    Extract information for a specific layer.

    Args:
        model_desc: Model description from adapter
        layer_idx: 0-based index of the layer

    Returns:
        Dictionary with layer-specific information; ``n_heads``/``mlp_dim``
        are None when the per-layer lists do not cover ``layer_idx``.

    Raises:
        IndexError: If ``layer_idx`` is negative or >= the layer count.
    """
    n_layers = model_desc.get("n_layer", 0)

    # Bug fix: also reject negative indices. Previously only the upper
    # bound was checked, so Python's negative indexing silently read
    # layer data from the wrong end of the per-layer lists.
    if not 0 <= layer_idx < n_layers:
        raise IndexError(f"Layer index {layer_idx} out of range [0, {n_layers})")

    heads_per_layer = model_desc.get("heads_per_layer", [])
    mlp_dims = model_desc.get("mlp_dims", [])

    layer_info = {
        "layer_idx": layer_idx,
        "n_heads": heads_per_layer[layer_idx]
        if layer_idx < len(heads_per_layer)
        else None,
        "mlp_dim": mlp_dims[layer_idx] if layer_idx < len(mlp_dims) else None,
        "hidden_size": model_desc.get("hidden_size"),
    }

    return layer_info
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def validate_edit_parameters(
|
|
107
|
+
params: dict[str, Any],
|
|
108
|
+
required_params: list[str],
|
|
109
|
+
optional_params: dict[str, Any] | None = None,
|
|
110
|
+
) -> tuple[bool, str]:
|
|
111
|
+
"""
|
|
112
|
+
Validate edit parameters.
|
|
113
|
+
|
|
114
|
+
Args:
|
|
115
|
+
params: Parameters to validate
|
|
116
|
+
required_params: List of required parameter names
|
|
117
|
+
optional_params: Dict of optional parameters with default values
|
|
118
|
+
|
|
119
|
+
Returns:
|
|
120
|
+
(is_valid, error_message)
|
|
121
|
+
"""
|
|
122
|
+
# Check required parameters
|
|
123
|
+
missing_required = []
|
|
124
|
+
for param in required_params:
|
|
125
|
+
if param not in params:
|
|
126
|
+
missing_required.append(param)
|
|
127
|
+
|
|
128
|
+
if missing_required:
|
|
129
|
+
return False, f"Missing required parameters: {missing_required}"
|
|
130
|
+
|
|
131
|
+
# Add default values for optional parameters
|
|
132
|
+
if optional_params:
|
|
133
|
+
for param, default_value in optional_params.items():
|
|
134
|
+
if param not in params:
|
|
135
|
+
params[param] = default_value
|
|
136
|
+
|
|
137
|
+
return True, "Parameters valid"
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def create_edit_metadata(
    edit_name: str,
    model_desc: dict[str, Any],
    parameters: dict[str, Any],
    result: dict[str, Any],
) -> dict[str, Any]:
    """
    Create standardized edit metadata.

    Args:
        edit_name: Name of the edit operation
        model_desc: Original model description
        parameters: Edit parameters used
        result: Edit operation results

    Returns:
        Standardized metadata dictionary; any extra keys in *result* are
        folded into the "results" section without overwriting the
        standard entries.
    """
    original_params = model_desc.get("total_params", 0)
    final_params = result.get("final_params", original_params)

    # Compression ratio with a divide-by-zero guard (1.0 == unchanged).
    ratio = 1.0 if original_params == 0 else final_params / original_params

    results: dict[str, Any] = {
        "final_params": final_params,
        "compression_ratio": ratio,
        "layers_modified": result.get("layers_modified", []),
        "success": result.get("success", True),
    }
    # Merge edit-specific entries; standard keys above win.
    for key, value in result.items():
        results.setdefault(key, value)

    return {
        "name": edit_name,
        "parameters": parameters.copy(),
        "original_model": {
            "n_layer": model_desc.get("n_layer", 0),
            "total_params": original_params,
            "model_type": model_desc.get("model_type", "unknown"),
        },
        "results": results,
    }
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def get_module_parameter_count(module) -> int:
    """
    Count parameters in a module (if torch is available).

    Args:
        module: PyTorch module

    Returns:
        Total element count across ``module.parameters()``; 0 when torch
        is not installed or *module* is not an ``nn.Module``.
    """
    if not TORCH_AVAILABLE or not isinstance(module, nn.Module):
        return 0
    return sum(param.numel() for param in module.parameters())
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def validate_compression_target(
|
|
209
|
+
target_ratio: Any, min_ratio: float = 0.1, max_ratio: float = 1.0
|
|
210
|
+
) -> bool:
|
|
211
|
+
"""
|
|
212
|
+
Validate compression target ratio.
|
|
213
|
+
|
|
214
|
+
Args:
|
|
215
|
+
target_ratio: Desired compression ratio
|
|
216
|
+
min_ratio: Minimum allowed ratio
|
|
217
|
+
max_ratio: Maximum allowed ratio
|
|
218
|
+
|
|
219
|
+
Returns:
|
|
220
|
+
True if ratio is valid
|
|
221
|
+
"""
|
|
222
|
+
if not isinstance(target_ratio, int | float):
|
|
223
|
+
return False
|
|
224
|
+
|
|
225
|
+
return min_ratio <= target_ratio <= max_ratio
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
def create_layer_mask(
|
|
229
|
+
n_layers: int, layers_to_edit: list[int] | None = None
|
|
230
|
+
) -> list[bool]:
|
|
231
|
+
"""
|
|
232
|
+
Create a boolean mask for layers to edit.
|
|
233
|
+
|
|
234
|
+
Args:
|
|
235
|
+
n_layers: Total number of layers
|
|
236
|
+
layers_to_edit: List of layer indices to edit (None = all layers)
|
|
237
|
+
|
|
238
|
+
Returns:
|
|
239
|
+
Boolean mask where True indicates layer should be edited
|
|
240
|
+
"""
|
|
241
|
+
if layers_to_edit is None:
|
|
242
|
+
return [True] * n_layers
|
|
243
|
+
|
|
244
|
+
mask = [False] * n_layers
|
|
245
|
+
for layer_idx in layers_to_edit:
|
|
246
|
+
if 0 <= layer_idx < n_layers:
|
|
247
|
+
mask[layer_idx] = True
|
|
248
|
+
|
|
249
|
+
return mask
|
|
@@ -0,0 +1,268 @@
|
|
|
1
|
+
"""
|
|
2
|
+
External Utilities for Edit Operations
|
|
3
|
+
=====================================
|
|
4
|
+
|
|
5
|
+
Utilities for integrating with external edit backends and guard chains.
|
|
6
|
+
Provides common functionality for model snapshots, validation, and guard policies.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import warnings
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
import torch
|
|
15
|
+
import torch.nn as nn
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ExternalBackend:
    """Base class for external edit backends.

    Subclasses implement both hooks; this base only defines the interface
    and raises ``NotImplementedError``.
    """

    def apply_edit(self, model: nn.Module, config: dict, calib: Any) -> dict:
        """Apply edit to model and return a result dictionary."""
        raise NotImplementedError()

    def get_edit_info(self, model: nn.Module, config: dict) -> dict:
        """Get edit information without applying the edit."""
        raise NotImplementedError()
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class ExternalEditWrapper:
    """Wrapper for external edit implementations.

    NOTE(review): ``preview``/``apply`` currently return empty placeholder
    structures and never delegate to ``self.backend`` — presumably a stub;
    confirm before relying on it.
    """

    def __init__(self, backend: ExternalBackend):
        # Backend is stored but not yet consulted by preview/apply.
        self.backend = backend

    def preview(self, model: nn.Module, adapter: Any, calib: Any) -> dict:
        """Preview edit operation (placeholder result)."""
        return {"plan": {}, "estimated_sparsity": {}, "preview_metrics": {}}

    def apply(self, model: nn.Module, adapter: Any, plan: dict) -> dict:
        """Apply edit operation (placeholder result)."""
        return {"actual_sparsity": {}, "modified_layers": [], "metrics": {}}
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def check_dependencies(edit_type: str) -> bool:
    """
    Check if dependencies for edit type are available.

    Args:
        edit_type: Type of edit (e.g., 'quant')

    Returns:
        True if dependencies are available
    """
    # Only quantization is supported — it uses standard PyTorch, so no
    # extra dependency is needed. Every other type ('svd' included) is
    # reported as unavailable.
    return edit_type == "quant"
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def validate_edit_config(config: dict, edit_type: str) -> dict:
    """
    Validate and normalize edit configuration.

    Args:
        config: Configuration dictionary
        edit_type: Type of edit operation

    Returns:
        Validated configuration dictionary (shallow copy of *config*)

    Raises:
        ValueError: For the unsupported 'svd' edit type.
    """
    validated = dict(config)

    if edit_type == "svd":
        raise ValueError("Unsupported edit type: svd")

    if edit_type == "quant":
        bits = validated.get("bits", 8)
        if bits not in (4, 8, 16):
            # Warn but do not reject — unusual widths may still be valid.
            warnings.warn(
                f"Unusual bit width {bits}, common values are 4, 8, or 16", stacklevel=2
            )

    return validated
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def compute_edit_metrics(
    model_before: nn.Module, model_after: nn.Module, config: dict
) -> dict:
    """
    Compute metrics comparing model before and after edit.

    Args:
        model_before: Model state before edit
        model_after: Model state after edit
        config: Edit configuration (not consulted here)

    Returns:
        Dictionary with parameter counts, non-zero counts, compression
        ratio, achieved sparsity, and estimated memory savings.
    """

    def _totals(model: nn.Module) -> tuple[int, int]:
        # One pass over parameters: (total element count, non-zero count).
        total = 0
        nonzero = 0
        for p in model.parameters():
            total += p.numel()
            nonzero += (p != 0).sum().item()
        return total, nonzero

    params_before, nonzero_before = _totals(model_before)
    params_after, nonzero_after = _totals(model_after)

    param_ratio = params_after / params_before if params_before > 0 else 1.0
    sparsity_ratio = (
        1.0 - (nonzero_after / nonzero_before) if nonzero_before > 0 else 0.0
    )

    return {
        "params_before": params_before,
        "params_after": params_after,
        "nonzero_before": nonzero_before,
        "nonzero_after": nonzero_after,
        "param_compression_ratio": param_ratio,
        "sparsity_achieved": sparsity_ratio,
        # Assumes 4 bytes per parameter (float32) — TODO confirm for
        # fp16/quantized weights.
        "memory_saved_mb": (params_before - params_after) * 4 / (1024 * 1024),
    }
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def prepare_calibration_data(calib_data: Any, config: dict) -> Any:
    """
    Prepare calibration data for edit operations.

    Args:
        calib_data: Raw calibration data
        config: Edit configuration (not consulted here)

    Returns:
        ``None`` when no data was given; a single tensor wrapped in a
        one-element list; otherwise the input unchanged (lists and other
        sized iterables are assumed to already be batch sequences).
    """
    if calib_data is None:
        return None

    # Bug fix: check for a bare tensor BEFORE the generic sized-iterable
    # test. torch.Tensor defines __iter__ and __len__, so the previous
    # ordering returned tensors unwrapped and the wrapping branch below
    # the hasattr check was unreachable.
    if isinstance(calib_data, torch.Tensor):
        return [calib_data]

    # Already batch-shaped (list, dataloader, ...): pass through as-is.
    if hasattr(calib_data, "__iter__") and hasattr(calib_data, "__len__"):
        return calib_data

    return calib_data
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def safe_model_snapshot(model: nn.Module, adapter: Any) -> dict:
    """
    Create a safe snapshot of model state for potential rollback.

    Args:
        model: Model to snapshot
        adapter: Model adapter for model-specific operations

    Returns:
        Dict with "state_dict"/"config"/"snapshot_time" keys, or an empty
        dict (with a warning) if snapshotting fails.
    """
    try:
        # Create deep copy of state dict for safety.
        # NOTE(review): only trainable (requires_grad) parameters are
        # captured — frozen parameters and buffers are NOT snapshotted;
        # confirm that is sufficient for rollback.
        state_dict = {}
        for name, param in model.named_parameters():
            if param.requires_grad:
                state_dict[name] = param.data.clone()

        # Store model configuration if available through adapter.
        model_config = {}
        if hasattr(adapter, "get_config"):
            try:
                model_config = adapter.get_config(model)
            except Exception:
                # Best-effort: the config is optional metadata.
                pass

        return {
            "state_dict": state_dict,
            "config": model_config,
            # NOTE(review): a CUDA Event is created but never recorded, so
            # this is not a usable timestamp — confirm intent.
            "snapshot_time": torch.cuda.Event(enable_timing=True)
            if torch.cuda.is_available()
            else None,
        }
    except Exception as e:
        warnings.warn(f"Failed to create model snapshot: {e}", stacklevel=2)
        return {}
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def safe_model_restore(model: nn.Module, snapshot: dict, adapter: Any) -> bool:
|
|
198
|
+
"""
|
|
199
|
+
Safely restore model from snapshot.
|
|
200
|
+
|
|
201
|
+
Args:
|
|
202
|
+
model: Model to restore
|
|
203
|
+
snapshot: Snapshot from safe_model_snapshot()
|
|
204
|
+
adapter: Model adapter for model-specific operations
|
|
205
|
+
|
|
206
|
+
Returns:
|
|
207
|
+
True if restore was successful
|
|
208
|
+
"""
|
|
209
|
+
try:
|
|
210
|
+
if "state_dict" not in snapshot:
|
|
211
|
+
warnings.warn("Snapshot missing state_dict, cannot restore", stacklevel=2)
|
|
212
|
+
return False
|
|
213
|
+
|
|
214
|
+
# Restore parameters
|
|
215
|
+
state_dict = snapshot["state_dict"]
|
|
216
|
+
for name, param in model.named_parameters():
|
|
217
|
+
if name in state_dict:
|
|
218
|
+
param.data.copy_(state_dict[name])
|
|
219
|
+
|
|
220
|
+
return True
|
|
221
|
+
except Exception as e:
|
|
222
|
+
warnings.warn(f"Failed to restore model from snapshot: {e}", stacklevel=2)
|
|
223
|
+
return False
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
def create_baseline_guard_policy(edit_type: str) -> dict:
    """
    Create baseline guard policy for edit type.

    Args:
        edit_type: Type of edit operation

    Returns:
        Guard policy configuration (a fresh dict on every call)
    """
    if edit_type == "quant":
        # Quantization gets looser thresholds and no spectral monitoring:
        # it legitimately perturbs weights more, and spectral properties
        # are expected to shift.
        return {
            "spectral_monitoring": False,
            "rmt_detection": True,
            "weight_change_threshold": 0.5,
            "activation_change_threshold": 0.3,
            "max_spectral_norm_increase": 3.0,
        }

    # Conservative defaults for any other edit type.
    return {
        "spectral_monitoring": True,
        "rmt_detection": True,
        "weight_change_threshold": 0.1,
        "activation_change_threshold": 0.1,
        "max_spectral_norm_increase": 2.0,
    }
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
__all__ = [
|
|
259
|
+
"ExternalBackend",
|
|
260
|
+
"ExternalEditWrapper",
|
|
261
|
+
"check_dependencies",
|
|
262
|
+
"validate_edit_config",
|
|
263
|
+
"compute_edit_metrics",
|
|
264
|
+
"prepare_calibration_data",
|
|
265
|
+
"safe_model_snapshot",
|
|
266
|
+
"safe_model_restore",
|
|
267
|
+
"create_baseline_guard_policy",
|
|
268
|
+
]
|