invarlock 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invarlock/__init__.py +33 -0
- invarlock/__main__.py +10 -0
- invarlock/_data/runtime/profiles/ci_cpu.yaml +15 -0
- invarlock/_data/runtime/profiles/release.yaml +23 -0
- invarlock/_data/runtime/tiers.yaml +76 -0
- invarlock/adapters/__init__.py +102 -0
- invarlock/adapters/_capabilities.py +45 -0
- invarlock/adapters/auto.py +99 -0
- invarlock/adapters/base.py +530 -0
- invarlock/adapters/base_types.py +85 -0
- invarlock/adapters/hf_bert.py +852 -0
- invarlock/adapters/hf_gpt2.py +403 -0
- invarlock/adapters/hf_llama.py +485 -0
- invarlock/adapters/hf_mixin.py +383 -0
- invarlock/adapters/hf_onnx.py +112 -0
- invarlock/adapters/hf_t5.py +137 -0
- invarlock/adapters/py.typed +1 -0
- invarlock/assurance/__init__.py +43 -0
- invarlock/cli/__init__.py +8 -0
- invarlock/cli/__main__.py +8 -0
- invarlock/cli/_evidence.py +25 -0
- invarlock/cli/_json.py +75 -0
- invarlock/cli/adapter_auto.py +162 -0
- invarlock/cli/app.py +287 -0
- invarlock/cli/commands/__init__.py +26 -0
- invarlock/cli/commands/certify.py +403 -0
- invarlock/cli/commands/doctor.py +1358 -0
- invarlock/cli/commands/explain_gates.py +151 -0
- invarlock/cli/commands/export_html.py +100 -0
- invarlock/cli/commands/plugins.py +1331 -0
- invarlock/cli/commands/report.py +354 -0
- invarlock/cli/commands/run.py +4146 -0
- invarlock/cli/commands/verify.py +1040 -0
- invarlock/cli/config.py +396 -0
- invarlock/cli/constants.py +68 -0
- invarlock/cli/device.py +92 -0
- invarlock/cli/doctor_helpers.py +74 -0
- invarlock/cli/errors.py +6 -0
- invarlock/cli/overhead_utils.py +60 -0
- invarlock/cli/provenance.py +66 -0
- invarlock/cli/utils.py +41 -0
- invarlock/config.py +56 -0
- invarlock/core/__init__.py +62 -0
- invarlock/core/abi.py +15 -0
- invarlock/core/api.py +274 -0
- invarlock/core/auto_tuning.py +317 -0
- invarlock/core/bootstrap.py +226 -0
- invarlock/core/checkpoint.py +221 -0
- invarlock/core/contracts.py +73 -0
- invarlock/core/error_utils.py +64 -0
- invarlock/core/events.py +298 -0
- invarlock/core/exceptions.py +95 -0
- invarlock/core/registry.py +481 -0
- invarlock/core/retry.py +146 -0
- invarlock/core/runner.py +2041 -0
- invarlock/core/types.py +154 -0
- invarlock/edits/__init__.py +12 -0
- invarlock/edits/_edit_utils.py +249 -0
- invarlock/edits/_external_utils.py +268 -0
- invarlock/edits/noop.py +47 -0
- invarlock/edits/py.typed +1 -0
- invarlock/edits/quant_rtn.py +801 -0
- invarlock/edits/registry.py +166 -0
- invarlock/eval/__init__.py +23 -0
- invarlock/eval/bench.py +1207 -0
- invarlock/eval/bootstrap.py +50 -0
- invarlock/eval/data.py +2052 -0
- invarlock/eval/metrics.py +2167 -0
- invarlock/eval/primary_metric.py +767 -0
- invarlock/eval/probes/__init__.py +24 -0
- invarlock/eval/probes/fft.py +139 -0
- invarlock/eval/probes/mi.py +213 -0
- invarlock/eval/probes/post_attention.py +323 -0
- invarlock/eval/providers/base.py +67 -0
- invarlock/eval/providers/seq2seq.py +111 -0
- invarlock/eval/providers/text_lm.py +113 -0
- invarlock/eval/providers/vision_text.py +93 -0
- invarlock/eval/py.typed +1 -0
- invarlock/guards/__init__.py +18 -0
- invarlock/guards/_contracts.py +9 -0
- invarlock/guards/invariants.py +640 -0
- invarlock/guards/policies.py +805 -0
- invarlock/guards/py.typed +1 -0
- invarlock/guards/rmt.py +2097 -0
- invarlock/guards/spectral.py +1419 -0
- invarlock/guards/tier_config.py +354 -0
- invarlock/guards/variance.py +3298 -0
- invarlock/guards_ref/__init__.py +15 -0
- invarlock/guards_ref/rmt_ref.py +40 -0
- invarlock/guards_ref/spectral_ref.py +135 -0
- invarlock/guards_ref/variance_ref.py +60 -0
- invarlock/model_profile.py +353 -0
- invarlock/model_utils.py +221 -0
- invarlock/observability/__init__.py +10 -0
- invarlock/observability/alerting.py +535 -0
- invarlock/observability/core.py +546 -0
- invarlock/observability/exporters.py +565 -0
- invarlock/observability/health.py +588 -0
- invarlock/observability/metrics.py +457 -0
- invarlock/observability/py.typed +1 -0
- invarlock/observability/utils.py +553 -0
- invarlock/plugins/__init__.py +12 -0
- invarlock/plugins/hello_guard.py +33 -0
- invarlock/plugins/hf_awq_adapter.py +82 -0
- invarlock/plugins/hf_bnb_adapter.py +79 -0
- invarlock/plugins/hf_gptq_adapter.py +78 -0
- invarlock/plugins/py.typed +1 -0
- invarlock/py.typed +1 -0
- invarlock/reporting/__init__.py +7 -0
- invarlock/reporting/certificate.py +3221 -0
- invarlock/reporting/certificate_schema.py +244 -0
- invarlock/reporting/dataset_hashing.py +215 -0
- invarlock/reporting/guards_analysis.py +948 -0
- invarlock/reporting/html.py +32 -0
- invarlock/reporting/normalizer.py +235 -0
- invarlock/reporting/policy_utils.py +517 -0
- invarlock/reporting/primary_metric_utils.py +265 -0
- invarlock/reporting/render.py +1442 -0
- invarlock/reporting/report.py +903 -0
- invarlock/reporting/report_types.py +278 -0
- invarlock/reporting/utils.py +175 -0
- invarlock/reporting/validate.py +631 -0
- invarlock/security.py +176 -0
- invarlock/sparsity_utils.py +323 -0
- invarlock/utils/__init__.py +150 -0
- invarlock/utils/digest.py +45 -0
- invarlock-0.2.0.dist-info/METADATA +586 -0
- invarlock-0.2.0.dist-info/RECORD +132 -0
- invarlock-0.2.0.dist-info/WHEEL +5 -0
- invarlock-0.2.0.dist-info/entry_points.txt +20 -0
- invarlock-0.2.0.dist-info/licenses/LICENSE +201 -0
- invarlock-0.2.0.dist-info/top_level.txt +1 -0
invarlock/security.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
"""
|
|
2
|
+
InvarLock Security Utilities
|
|
3
|
+
========================
|
|
4
|
+
|
|
5
|
+
Runtime hardening helpers used by the CLI and automation surfaces.
|
|
6
|
+
|
|
7
|
+
- Network guard: disables outbound socket connections unless explicitly allowed.
|
|
8
|
+
- Secure temporary directory helper ensuring 0o700 permissions and cleanup.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import contextlib
|
|
14
|
+
import os
|
|
15
|
+
import shutil
|
|
16
|
+
import socket
|
|
17
|
+
import stat
|
|
18
|
+
import tempfile
|
|
19
|
+
from collections.abc import Callable, Iterator
|
|
20
|
+
from pathlib import Path
|
|
21
|
+
from typing import Any
|
|
22
|
+
|
|
23
|
+
__all__ = [
|
|
24
|
+
"NetworkGuard",
|
|
25
|
+
"enforce_network_policy",
|
|
26
|
+
"enforce_default_security",
|
|
27
|
+
"temporarily_allow_network",
|
|
28
|
+
"network_policy_allows",
|
|
29
|
+
"secure_tempdir",
|
|
30
|
+
"is_secure_path",
|
|
31
|
+
]
|
|
32
|
+
|
|
33
|
+
# Shared template for the "network disabled" error. Kept as an exception
# instance for backward compatibility; the guard raises a *fresh* copy per
# violation, because re-raising one shared exception instance mutates its
# __traceback__/__context__ on every raise, retaining frames and producing
# confusing chained tracebacks.
_NETWORK_DISABLED_ERROR = RuntimeError(
    "Network access disabled by InvarLock security policy. "
    "Set INVARLOCK_ALLOW_NETWORK=1 if connectivity is required."
)


def _network_disabled_error() -> RuntimeError:
    """Return a new RuntimeError carrying the network-disabled policy message."""
    return RuntimeError(*_NETWORK_DISABLED_ERROR.args)


class NetworkGuard:
    """Installable network guard that blocks outbound socket connections.

    ``install`` monkey-patches ``socket.socket`` (with a subclass whose
    ``connect`` always raises) and ``socket.create_connection``; ``restore``
    puts the captured originals back. Both methods are idempotent: calling
    them when already in the requested state is a no-op.
    """

    def __init__(self) -> None:
        self._installed = False
        # Originals are captured at install time so restore() can undo the patch.
        self._original_socket_cls: type[socket.socket] | None = None
        self._original_create_connection: Callable[..., socket.socket] | None = None

    @property
    def installed(self) -> bool:
        """Whether the guard is currently blocking network access."""
        return self._installed

    def install(self) -> None:
        """Install the guard if not already active."""
        if self._installed:
            return

        # socket.socket is still the genuine class here (patch happens below),
        # so GuardedSocket subclasses the real implementation.
        self._original_socket_cls = socket.socket
        self._original_create_connection = socket.create_connection

        class GuardedSocket(socket.socket):
            """Socket subclass that blocks connect calls."""

            def connect(self_inner, address: Any) -> None:
                raise _network_disabled_error()

        def guarded_create_connection(
            address: Any,
            timeout: float | None = None,
            source_address: Any | None = None,
        ) -> socket.socket:
            raise _network_disabled_error()

        setattr(socket, "socket", GuardedSocket)  # noqa: B010
        setattr(socket, "create_connection", guarded_create_connection)  # noqa: B010
        self._installed = True

    def restore(self) -> None:
        """Restore the original socket implementations."""
        if not self._installed:
            return

        if self._original_socket_cls is not None:
            setattr(socket, "socket", self._original_socket_cls)  # noqa: B010
        if self._original_create_connection is not None:
            setattr(socket, "create_connection", self._original_create_connection)  # noqa: B010
        self._installed = False
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
_GUARD = NetworkGuard()
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def enforce_network_policy(allow: bool) -> None:
    """Apply the global network policy.

    Args:
        allow: When True the guard is removed, otherwise network access is blocked.
    """
    action = _GUARD.restore if allow else _GUARD.install
    action()
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def network_policy_allows() -> bool:
    """Return True if outbound connections are currently permitted."""
    blocked = _GUARD.installed
    return not blocked
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def enforce_default_security() -> None:
    """Enforce default runtime security posture.

    Network access is denied unless INVARLOCK_ALLOW_NETWORK is set to a
    truthy value ("1", "true", "yes", or "on", case-insensitive).
    """
    truthy_values = {"1", "true", "yes", "on"}
    raw = os.environ.get("INVARLOCK_ALLOW_NETWORK", "")
    enforce_network_policy(raw.strip().lower() in truthy_values)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
@contextlib.contextmanager
def temporarily_allow_network() -> Iterator[None]:
    """Temporarily allow network access inside the context block.

    Restores the previous policy when exiting the context.
    """
    guard_was_active = _GUARD.installed
    if not guard_was_active:
        # Guard is already off; run the body under the existing open policy.
        yield
        return
    _GUARD.restore()
    try:
        yield
    finally:
        _GUARD.install()
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
@contextlib.contextmanager
def secure_tempdir(
    prefix: str = "invarlock-", base_dir: str | os.PathLike[str] | None = None
) -> Iterator[Path]:
    """Create a temporary directory with 0o700 permissions, removed on exit.

    Args:
        prefix: Directory name prefix.
        base_dir: Optional base directory.

    Yields:
        Path to the secure temporary directory.
    """
    tmp_path = Path(tempfile.mkdtemp(prefix=prefix, dir=base_dir))
    # Tighten permissions explicitly rather than relying on mkdtemp defaults.
    tmp_path.chmod(0o700)
    try:
        yield tmp_path
    finally:
        # Best-effort cleanup; never raise out of the finally block.
        shutil.rmtree(tmp_path, ignore_errors=True)
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def is_secure_path(path: Path) -> bool:
    """Check whether a path has secure (0o700) permissions.

    Args:
        path: Path to validate.

    Returns:
        True if the path exists and has exactly 0o700 permissions, False
        otherwise — including when the path cannot be stat'ed at all.
    """
    try:
        mode = path.stat().st_mode
    except OSError:
        # Broadened from FileNotFoundError: PermissionError, NotADirectoryError,
        # etc. all mean we cannot vouch for the path, so treat it as insecure
        # instead of propagating.
        return False
    return stat.S_IMODE(mode) == 0o700
|
|
@@ -0,0 +1,323 @@
|
|
|
1
|
+
"""
|
|
2
|
+
InvarLock Sparsity Utilities
|
|
3
|
+
========================
|
|
4
|
+
|
|
5
|
+
Utilities for working with sparse models and sparsity patterns.
|
|
6
|
+
Helper functions for analyzing and manipulating model sparsity.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import warnings
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
try:
|
|
15
|
+
import torch
|
|
16
|
+
import torch.nn as nn
|
|
17
|
+
|
|
18
|
+
TORCH_AVAILABLE = True
|
|
19
|
+
except ImportError:
|
|
20
|
+
TORCH_AVAILABLE = False
|
|
21
|
+
|
|
22
|
+
__all__ = [
|
|
23
|
+
"calculate_sparsity",
|
|
24
|
+
"get_zero_mask",
|
|
25
|
+
"apply_mask",
|
|
26
|
+
"count_parameters",
|
|
27
|
+
"get_sparsity_stats",
|
|
28
|
+
"create_structured_mask",
|
|
29
|
+
"validate_sparsity_target",
|
|
30
|
+
]
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def calculate_sparsity(tensor) -> float:
    """Calculate sparsity ratio of a tensor.

    Args:
        tensor: PyTorch tensor

    Returns:
        Sparsity ratio (fraction of exactly-zero elements). Returns 0.0 when
        torch is unavailable, the input is not a tensor, or the tensor is empty.
    """
    # Short-circuit keeps the isinstance check from running when torch is absent.
    if not TORCH_AVAILABLE or not isinstance(tensor, torch.Tensor):
        return 0.0

    element_count = tensor.numel()
    if element_count == 0:
        return 0.0

    return (tensor == 0).sum().item() / element_count
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def get_zero_mask(tensor, threshold: float = 1e-8):
    """Get boolean mask of effectively zero elements.

    Args:
        tensor: PyTorch tensor
        threshold: Magnitude at or below which a value counts as zero

    Returns:
        Boolean mask where True indicates zero/near-zero elements, or None
        when torch is unavailable or the input is not a tensor.
    """
    if not TORCH_AVAILABLE or not isinstance(tensor, torch.Tensor):
        return None
    return tensor.abs() <= threshold
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def apply_mask(tensor, mask, fill_value: float = 0.0):
    """Apply a boolean mask to zero out elements.

    Args:
        tensor: PyTorch tensor to modify
        mask: Boolean mask (True = positions to overwrite)
        fill_value: Value written at masked positions

    Returns:
        A new tensor with masked positions filled. The input is returned
        unchanged when torch is unavailable or either argument is not a tensor.
    """
    if not TORCH_AVAILABLE:
        return tensor
    if not (isinstance(tensor, torch.Tensor) and isinstance(mask, torch.Tensor)):
        return tensor

    # Work on a copy so the caller's tensor is never mutated.
    filled = tensor.clone()
    filled[mask] = fill_value
    return filled
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def count_parameters(module, only_trainable: bool = True) -> dict[str, int]:
    """
    Count parameters in a module.

    Args:
        module: PyTorch module
        only_trainable: Count only trainable parameters.
            NOTE(review): this flag is currently unused — all three counts are
            always computed and returned regardless of its value; confirm the
            intended behavior before relying on it.

    Returns:
        Dictionary with parameter counts under the keys "total", "trainable",
        and "non_trainable".
    """
    # Graceful degradation: zeros when torch is missing or input is not a module.
    if not TORCH_AVAILABLE or not isinstance(module, nn.Module):
        return {"total": 0, "trainable": 0, "non_trainable": 0}

    total_params = 0
    trainable_params = 0
    non_trainable_params = 0

    for param in module.parameters():
        param_count = param.numel()
        total_params += param_count

        # requires_grad distinguishes trainable from frozen parameters.
        if param.requires_grad:
            trainable_params += param_count
        else:
            non_trainable_params += param_count

    return {
        "total": total_params,
        "trainable": trainable_params,
        "non_trainable": non_trainable_params,
    }
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def get_sparsity_stats(module) -> dict[str, Any]:
    """Get comprehensive sparsity statistics for a module.

    Args:
        module: PyTorch module

    Returns:
        Dictionary with overall sparsity, per-layer sparsities, parameter
        counts, and sparse/dense layer tallies. Empty dict when torch is
        unavailable or ``module`` is not an ``nn.Module``.
    """
    if not TORCH_AVAILABLE or not isinstance(module, nn.Module):
        return {}

    stats: dict[str, Any] = {
        "overall_sparsity": 0.0,
        "layer_sparsities": {},
        "parameter_counts": count_parameters(module),
        "sparse_layers": 0,
        "dense_layers": 0,
    }

    total_elements = 0
    total_zeros = 0

    for name, param in module.named_parameters():
        elements = param.numel()
        if elements == 0:
            continue

        # Count zeros once and reuse the result; the previous version ran the
        # same (param == 0).sum() reduction twice per parameter (once inside
        # calculate_sparsity and once inline).
        zeros = int((param == 0).sum().item())
        layer_sparsity = zeros / elements
        stats["layer_sparsities"][name] = layer_sparsity

        total_elements += elements
        total_zeros += zeros

        # Classify layer as sparse or dense (>10% zeros counts as sparse).
        if layer_sparsity > 0.1:
            stats["sparse_layers"] += 1
        else:
            stats["dense_layers"] += 1

    if total_elements > 0:
        stats["overall_sparsity"] = total_zeros / total_elements

    return stats
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def create_structured_mask(
    shape: tuple[int, ...], pattern: str = "block", block_size: int = 4
) -> torch.Tensor | None:
    """Create block/group sparsity masks.

    Args:
        shape: Shape of the tensor
        pattern: Sparsity pattern ('block', 'column', 'row')
        block_size: Size of blocks for block sparsity

    Returns:
        Boolean mask tensor (True marks masked positions), or None when torch
        is unavailable. Unrecognized pattern/shape combinations warn and
        return an all-False mask.
    """
    if not TORCH_AVAILABLE:
        return None

    mask = torch.zeros(shape, dtype=torch.bool)
    has_two_dims = len(shape) >= 2

    if pattern == "block" and has_two_dims:
        # Checkerboard of block_size x block_size blocks.
        for row in range(0, shape[0], block_size):
            for col in range(0, shape[1], block_size):
                if (row // block_size + col // block_size) % 2 == 0:
                    row_end = min(row + block_size, shape[0])
                    col_end = min(col + block_size, shape[1])
                    mask[row:row_end, col:col_end] = True
    elif pattern == "column" and has_two_dims:
        # Mask every other column.
        mask[:, ::2] = True
    elif pattern == "row" and has_two_dims:
        # Mask every other row.
        mask[::2, :] = True
    else:
        warnings.warn(
            f"Unsupported sparsity pattern '{pattern}' for shape {shape}", stacklevel=2
        )

    return mask
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def validate_sparsity_target(target_sparsity: Any) -> bool:
|
|
231
|
+
"""
|
|
232
|
+
Validate sparsity target value.
|
|
233
|
+
|
|
234
|
+
Args:
|
|
235
|
+
target_sparsity: Target sparsity ratio (0.0 to 1.0)
|
|
236
|
+
|
|
237
|
+
Returns:
|
|
238
|
+
True if valid
|
|
239
|
+
"""
|
|
240
|
+
if not isinstance(target_sparsity, int | float):
|
|
241
|
+
return False
|
|
242
|
+
return 0.0 <= target_sparsity <= 1.0
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
def get_magnitude_mask(tensor, sparsity_ratio: float) -> torch.Tensor | None:
    """Create magnitude-based sparsity mask.

    Args:
        tensor: PyTorch tensor
        sparsity_ratio: Target sparsity fraction in [0.0, 1.0]

    Returns:
        Boolean mask where True indicates weights to zero, or None when torch
        is unavailable / the input is not a tensor.

    Raises:
        ValueError: If ``sparsity_ratio`` is not a number within [0.0, 1.0].
    """
    if not TORCH_AVAILABLE or not isinstance(tensor, torch.Tensor):
        return None

    if not validate_sparsity_target(sparsity_ratio):
        raise ValueError(f"Invalid sparsity ratio: {sparsity_ratio}")

    # reshape(-1) instead of view(-1): view raises for non-contiguous tensors
    # (e.g. transposed weights); reshape copies only when it must.
    flat_tensor = tensor.reshape(-1)

    num_to_zero = int(sparsity_ratio * flat_tensor.numel())
    if num_to_zero == 0:
        return torch.zeros_like(tensor, dtype=torch.bool)

    # The k-th smallest magnitude is the cut-off; ties at the threshold may
    # zero slightly more than the requested fraction.
    magnitudes = torch.abs(flat_tensor)
    threshold_value = torch.kthvalue(magnitudes, num_to_zero).values

    return torch.abs(tensor) <= threshold_value
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
def analyze_sparsity_impact(original_tensor, edited_tensor) -> dict[str, Any]:
    """Analyze the impact of applied sparsity on a tensor.

    Args:
        original_tensor: Original tensor before changes
        edited_tensor: Tensor after applying sparsity

    Returns:
        Dictionary with impact analysis (sparsity before/after, delta, norm of
        the change, relative change, compression ratio). Empty dict when torch
        is unavailable or either input is not a tensor.
    """
    if not TORCH_AVAILABLE:
        return {}

    if not (
        isinstance(original_tensor, torch.Tensor)
        and isinstance(edited_tensor, torch.Tensor)
    ):
        return {}

    original_sparsity = calculate_sparsity(original_tensor)
    final_sparsity = calculate_sparsity(edited_tensor)

    # Hoisted: the previous version recomputed torch.norm(original_tensor)
    # separately for the ratio and for the zero check.
    magnitude_change = torch.norm(edited_tensor - original_tensor).item()
    original_norm = torch.norm(original_tensor).item()
    relative_change = magnitude_change / original_norm if original_norm > 0 else 0.0

    return {
        "original_sparsity": original_sparsity,
        "final_sparsity": final_sparsity,
        "sparsity_increase": final_sparsity - original_sparsity,
        "magnitude_change": magnitude_change,
        "relative_change": relative_change,
        # inf means the edit introduced sparsity where there was none before.
        "compression_ratio": (
            final_sparsity / original_sparsity
            if original_sparsity > 0
            else float("inf")
        ),
    }
|
|
@@ -0,0 +1,150 @@
|
|
|
1
|
+
"""
|
|
2
|
+
InvarLock Utilities
|
|
3
|
+
===============
|
|
4
|
+
|
|
5
|
+
Common utility functions used across InvarLock modules.
|
|
6
|
+
|
|
7
|
+
This package also exposes submodules such as `invarlock.utils.digest` for
|
|
8
|
+
hashing and provenance utilities.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
from typing import Any
|
|
14
|
+
|
|
15
|
+
import psutil
|
|
16
|
+
|
|
17
|
+
try: # Torch is optional; utils can be imported without it.
|
|
18
|
+
import torch # type: ignore[import]
|
|
19
|
+
except Exception: # pragma: no cover - exercised when torch is missing
|
|
20
|
+
torch = None # type: ignore[assignment]
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def extract_input_ids(
    batch: Any, device: str | None = None, strict: bool = False
) -> torch.Tensor:
    """Extract input_ids from various batch formats.

    Args:
        batch: Input batch (tensor, dict, object with an ``input_ids``
            attribute, or anything convertible via ``torch.tensor``)
        device: Target device for the returned tensor
        strict: Whether to raise on unrecognized formats instead of
            attempting a best-effort extraction/conversion

    Returns:
        Extracted input_ids tensor

    Raises:
        ValueError: When the format is unrecognized (always for tensor-free
            dicts; otherwise only when ``strict`` is True).
    """
    if isinstance(batch, torch.Tensor):
        ids = batch
    elif isinstance(batch, dict):
        if "input_ids" in batch:
            ids = batch["input_ids"]
        elif "inputs" in batch:
            ids = batch["inputs"]
        else:
            if strict:
                raise ValueError(
                    f"Dict batch missing 'input_ids' or 'inputs' keys: {list(batch.keys())}"
                )
            # Best effort: first tensor-valued entry wins.
            ids = next(
                (v for v in batch.values() if isinstance(v, torch.Tensor)), None
            )
            if ids is None:
                raise ValueError("No tensor found in batch dict")
    elif hasattr(batch, "input_ids"):
        ids = batch.input_ids
    else:
        if strict:
            raise ValueError(f"Unsupported batch format: {type(batch)}")
        # Best effort: let torch attempt the conversion.
        ids = torch.tensor(batch)

    return ids.to(device) if device is not None else ids
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def get_model_device(model: torch.nn.Module) -> torch.device:
    """Return the device of a model, taken from its first parameter.

    NOTE(review): like the original, this raises StopIteration for a
    parameter-less model — confirm callers never pass one.
    """
    first_param = next(model.parameters())
    return first_param.device
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def ensure_tensor(data: Any, device: torch.device | None = None) -> torch.Tensor:
    """Ensure data is a tensor, optionally moved to the given device."""
    tensor = data if isinstance(data, torch.Tensor) else torch.tensor(data)
    return tensor.to(device) if device is not None else tensor
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def safe_divide(numerator: float, denominator: float, default: float = 0.0) -> float:
    """Divide two numbers, returning ``default`` for near-zero denominators.

    A denominator with magnitude below 1e-12 is treated as zero.
    """
    if abs(denominator) >= 1e-12:
        return numerator / denominator
    return default
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def dict_to_device(
    data: dict[str, torch.Tensor], device: torch.device
) -> dict[str, torch.Tensor]:
    """Return a copy of ``data`` with every tensor value moved to ``device``.

    Non-tensor values are passed through unchanged.
    """
    moved: dict[str, torch.Tensor] = {}
    for key, value in data.items():
        moved[key] = value.to(device) if isinstance(value, torch.Tensor) else value
    return moved
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def format_number(num: float, precision: int = 3) -> str:
    """Format a number for display.

    Magnitudes below 1e-3 use scientific notation; values below 1 get one
    extra digit of fixed-point precision.
    """
    magnitude = abs(num)
    if magnitude < 1e-3:
        return f"{num:.2e}"
    digits = precision + 1 if magnitude < 1 else precision
    return f"{num:.{digits}f}"
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def get_memory_usage() -> dict[str, float]:
    """Get current memory usage in MB (RSS/VMS, plus CUDA stats if available)."""
    import gc

    # Collect first so the numbers reflect live objects only.
    gc.collect()

    mem = psutil.Process().memory_info()
    usage = {
        "rss_mb": mem.rss / 1024 / 1024,  # Resident Set Size
        "vms_mb": mem.vms / 1024 / 1024,  # Virtual Memory Size
    }

    try:
        if torch is not None and hasattr(torch, "cuda") and torch.cuda.is_available():
            usage["cuda_allocated_mb"] = torch.cuda.memory_allocated() / 1024 / 1024
            usage["cuda_reserved_mb"] = torch.cuda.memory_reserved() / 1024 / 1024
    except Exception:
        # If torch is unavailable or querying CUDA fails, fall back to CPU-only stats.
        pass

    return usage
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
__all__ = [
|
|
143
|
+
"extract_input_ids",
|
|
144
|
+
"get_model_device",
|
|
145
|
+
"ensure_tensor",
|
|
146
|
+
"safe_divide",
|
|
147
|
+
"dict_to_device",
|
|
148
|
+
"format_number",
|
|
149
|
+
"get_memory_usage",
|
|
150
|
+
]
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import hashlib
|
|
4
|
+
import json
|
|
5
|
+
from typing import Any, Protocol
|
|
6
|
+
|
|
7
|
+
_ENC = "utf-8"


class _HashLike(Protocol):
    """Minimal interface of hashlib hash objects used in this module."""

    def update(self, data: bytes, /) -> None: ...
    def hexdigest(self, /) -> str: ...


def _h() -> _HashLike:
    """Return a fresh 32-byte BLAKE2s hasher."""
    return hashlib.blake2s(digest_size=32)


def hash_bytes(b: bytes, *, salt: bytes | None = None) -> str:
    """Hash raw bytes (optionally prefixed with ``salt``); return hex digest."""
    h = _h()
    if salt:
        h.update(salt)
    h.update(b)
    return h.hexdigest()


def hash_json(obj: Any, *, salt: str | None = None) -> str:
    """Hash a JSON-serializable object via its canonical (sorted, compact) form."""
    s = json.dumps(obj, sort_keys=True, separators=(",", ":"), ensure_ascii=False)
    return hash_bytes(s.encode(_ENC), salt=salt.encode(_ENC) if salt else None)


def hash_int_array(arr, *, salt: str | None = None, byteorder: str = "little") -> str:
    """Hash a sequence of integers serialized as 32-bit values.

    Args:
        arr: Sequence (or numpy array) of integers; values are truncated to
            32 bits.
        salt: Optional salt mixed in before the data.
        byteorder: Byte order ("little" or "big") for the 4-byte values in the
            pure-Python fallback. The numpy fast path emits native-order int32
            bytes, which matches the default "little" on little-endian hosts.

    Returns:
        Hex digest of the serialized array.
    """
    salt_bytes = salt.encode(_ENC) if salt else None
    # Accept numpy arrays to avoid extra copies
    try:
        import numpy as _np
    except Exception:  # pragma: no cover - numpy always present in tests
        # Fallback fixed: serialize each value as 4 bytes so the digest matches
        # the numpy int32 path. The previous fallback emitted one byte per value
        # (raising ValueError for masked values > 255) and ignored `byteorder`,
        # producing digests inconsistent with the numpy branch.
        b = b"".join((int(x) & 0xFFFFFFFF).to_bytes(4, byteorder) for x in arr)
        return hash_bytes(b, salt=salt_bytes)

    a = _np.asarray(arr, dtype=_np.int32, order="C")
    return hash_bytes(a.tobytes(order="C"), salt=salt_bytes)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
__all__ = ["hash_bytes", "hash_json", "hash_int_array"]
|