invarlock 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132) hide show
  1. invarlock/__init__.py +33 -0
  2. invarlock/__main__.py +10 -0
  3. invarlock/_data/runtime/profiles/ci_cpu.yaml +15 -0
  4. invarlock/_data/runtime/profiles/release.yaml +23 -0
  5. invarlock/_data/runtime/tiers.yaml +76 -0
  6. invarlock/adapters/__init__.py +102 -0
  7. invarlock/adapters/_capabilities.py +45 -0
  8. invarlock/adapters/auto.py +99 -0
  9. invarlock/adapters/base.py +530 -0
  10. invarlock/adapters/base_types.py +85 -0
  11. invarlock/adapters/hf_bert.py +852 -0
  12. invarlock/adapters/hf_gpt2.py +403 -0
  13. invarlock/adapters/hf_llama.py +485 -0
  14. invarlock/adapters/hf_mixin.py +383 -0
  15. invarlock/adapters/hf_onnx.py +112 -0
  16. invarlock/adapters/hf_t5.py +137 -0
  17. invarlock/adapters/py.typed +1 -0
  18. invarlock/assurance/__init__.py +43 -0
  19. invarlock/cli/__init__.py +8 -0
  20. invarlock/cli/__main__.py +8 -0
  21. invarlock/cli/_evidence.py +25 -0
  22. invarlock/cli/_json.py +75 -0
  23. invarlock/cli/adapter_auto.py +162 -0
  24. invarlock/cli/app.py +287 -0
  25. invarlock/cli/commands/__init__.py +26 -0
  26. invarlock/cli/commands/certify.py +403 -0
  27. invarlock/cli/commands/doctor.py +1358 -0
  28. invarlock/cli/commands/explain_gates.py +151 -0
  29. invarlock/cli/commands/export_html.py +100 -0
  30. invarlock/cli/commands/plugins.py +1331 -0
  31. invarlock/cli/commands/report.py +354 -0
  32. invarlock/cli/commands/run.py +4146 -0
  33. invarlock/cli/commands/verify.py +1040 -0
  34. invarlock/cli/config.py +396 -0
  35. invarlock/cli/constants.py +68 -0
  36. invarlock/cli/device.py +92 -0
  37. invarlock/cli/doctor_helpers.py +74 -0
  38. invarlock/cli/errors.py +6 -0
  39. invarlock/cli/overhead_utils.py +60 -0
  40. invarlock/cli/provenance.py +66 -0
  41. invarlock/cli/utils.py +41 -0
  42. invarlock/config.py +56 -0
  43. invarlock/core/__init__.py +62 -0
  44. invarlock/core/abi.py +15 -0
  45. invarlock/core/api.py +274 -0
  46. invarlock/core/auto_tuning.py +317 -0
  47. invarlock/core/bootstrap.py +226 -0
  48. invarlock/core/checkpoint.py +221 -0
  49. invarlock/core/contracts.py +73 -0
  50. invarlock/core/error_utils.py +64 -0
  51. invarlock/core/events.py +298 -0
  52. invarlock/core/exceptions.py +95 -0
  53. invarlock/core/registry.py +481 -0
  54. invarlock/core/retry.py +146 -0
  55. invarlock/core/runner.py +2041 -0
  56. invarlock/core/types.py +154 -0
  57. invarlock/edits/__init__.py +12 -0
  58. invarlock/edits/_edit_utils.py +249 -0
  59. invarlock/edits/_external_utils.py +268 -0
  60. invarlock/edits/noop.py +47 -0
  61. invarlock/edits/py.typed +1 -0
  62. invarlock/edits/quant_rtn.py +801 -0
  63. invarlock/edits/registry.py +166 -0
  64. invarlock/eval/__init__.py +23 -0
  65. invarlock/eval/bench.py +1207 -0
  66. invarlock/eval/bootstrap.py +50 -0
  67. invarlock/eval/data.py +2052 -0
  68. invarlock/eval/metrics.py +2167 -0
  69. invarlock/eval/primary_metric.py +767 -0
  70. invarlock/eval/probes/__init__.py +24 -0
  71. invarlock/eval/probes/fft.py +139 -0
  72. invarlock/eval/probes/mi.py +213 -0
  73. invarlock/eval/probes/post_attention.py +323 -0
  74. invarlock/eval/providers/base.py +67 -0
  75. invarlock/eval/providers/seq2seq.py +111 -0
  76. invarlock/eval/providers/text_lm.py +113 -0
  77. invarlock/eval/providers/vision_text.py +93 -0
  78. invarlock/eval/py.typed +1 -0
  79. invarlock/guards/__init__.py +18 -0
  80. invarlock/guards/_contracts.py +9 -0
  81. invarlock/guards/invariants.py +640 -0
  82. invarlock/guards/policies.py +805 -0
  83. invarlock/guards/py.typed +1 -0
  84. invarlock/guards/rmt.py +2097 -0
  85. invarlock/guards/spectral.py +1419 -0
  86. invarlock/guards/tier_config.py +354 -0
  87. invarlock/guards/variance.py +3298 -0
  88. invarlock/guards_ref/__init__.py +15 -0
  89. invarlock/guards_ref/rmt_ref.py +40 -0
  90. invarlock/guards_ref/spectral_ref.py +135 -0
  91. invarlock/guards_ref/variance_ref.py +60 -0
  92. invarlock/model_profile.py +353 -0
  93. invarlock/model_utils.py +221 -0
  94. invarlock/observability/__init__.py +10 -0
  95. invarlock/observability/alerting.py +535 -0
  96. invarlock/observability/core.py +546 -0
  97. invarlock/observability/exporters.py +565 -0
  98. invarlock/observability/health.py +588 -0
  99. invarlock/observability/metrics.py +457 -0
  100. invarlock/observability/py.typed +1 -0
  101. invarlock/observability/utils.py +553 -0
  102. invarlock/plugins/__init__.py +12 -0
  103. invarlock/plugins/hello_guard.py +33 -0
  104. invarlock/plugins/hf_awq_adapter.py +82 -0
  105. invarlock/plugins/hf_bnb_adapter.py +79 -0
  106. invarlock/plugins/hf_gptq_adapter.py +78 -0
  107. invarlock/plugins/py.typed +1 -0
  108. invarlock/py.typed +1 -0
  109. invarlock/reporting/__init__.py +7 -0
  110. invarlock/reporting/certificate.py +3221 -0
  111. invarlock/reporting/certificate_schema.py +244 -0
  112. invarlock/reporting/dataset_hashing.py +215 -0
  113. invarlock/reporting/guards_analysis.py +948 -0
  114. invarlock/reporting/html.py +32 -0
  115. invarlock/reporting/normalizer.py +235 -0
  116. invarlock/reporting/policy_utils.py +517 -0
  117. invarlock/reporting/primary_metric_utils.py +265 -0
  118. invarlock/reporting/render.py +1442 -0
  119. invarlock/reporting/report.py +903 -0
  120. invarlock/reporting/report_types.py +278 -0
  121. invarlock/reporting/utils.py +175 -0
  122. invarlock/reporting/validate.py +631 -0
  123. invarlock/security.py +176 -0
  124. invarlock/sparsity_utils.py +323 -0
  125. invarlock/utils/__init__.py +150 -0
  126. invarlock/utils/digest.py +45 -0
  127. invarlock-0.2.0.dist-info/METADATA +586 -0
  128. invarlock-0.2.0.dist-info/RECORD +132 -0
  129. invarlock-0.2.0.dist-info/WHEEL +5 -0
  130. invarlock-0.2.0.dist-info/entry_points.txt +20 -0
  131. invarlock-0.2.0.dist-info/licenses/LICENSE +201 -0
  132. invarlock-0.2.0.dist-info/top_level.txt +1 -0
invarlock/security.py ADDED
@@ -0,0 +1,176 @@
1
+ """
2
+ InvarLock Security Utilities
3
+ ========================
4
+
5
+ Runtime hardening helpers used by the CLI and automation surfaces.
6
+
7
+ - Network guard: disables outbound socket connections unless explicitly allowed.
8
+ - Secure temporary directory helper ensuring 0o700 permissions and cleanup.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import contextlib
14
+ import os
15
+ import shutil
16
+ import socket
17
+ import stat
18
+ import tempfile
19
+ from collections.abc import Callable, Iterator
20
+ from pathlib import Path
21
+ from typing import Any
22
+
23
+ __all__ = [
24
+ "NetworkGuard",
25
+ "enforce_network_policy",
26
+ "enforce_default_security",
27
+ "temporarily_allow_network",
28
+ "network_policy_allows",
29
+ "secure_tempdir",
30
+ "is_secure_path",
31
+ ]
32
+
33
+ _NETWORK_DISABLED_ERROR = RuntimeError(
34
+ "Network access disabled by InvarLock security policy. "
35
+ "Set INVARLOCK_ALLOW_NETWORK=1 if connectivity is required."
36
+ )
37
+
38
+
39
class NetworkGuard:
    """Installable guard that blocks outbound socket connections.

    While installed, ``socket.socket`` is replaced by a subclass whose
    outbound entry points (``connect``, ``connect_ex``, ``sendto``) raise
    ``RuntimeError``, and ``socket.create_connection`` raises
    unconditionally. ``restore()`` reinstates the original objects.
    """

    def __init__(self) -> None:
        self._installed = False
        self._original_socket_cls: type[socket.socket] | None = None
        self._original_create_connection: Callable[..., socket.socket] | None = None

    @property
    def installed(self) -> bool:
        """Whether the guard is currently blocking network access."""
        return self._installed

    def install(self) -> None:
        """Install the guard if not already active (idempotent)."""
        if self._installed:
            return

        self._original_socket_cls = socket.socket
        self._original_create_connection = socket.create_connection

        def _blocked() -> RuntimeError:
            # Build a fresh exception per raise; re-raising a single shared
            # module-level instance accumulates traceback/context state
            # across unrelated call sites.
            return RuntimeError(
                "Network access disabled by InvarLock security policy. "
                "Set INVARLOCK_ALLOW_NETWORK=1 if connectivity is required."
            )

        class GuardedSocket(socket.socket):
            """Socket subclass whose outbound entry points always raise."""

            def connect(self_inner, address: Any) -> None:
                raise _blocked()

            def connect_ex(self_inner, address: Any) -> int:
                # connect_ex() reports failures via errno instead of raising,
                # so leaving it unguarded would be a trivial bypass.
                raise _blocked()

            def sendto(self_inner, *args: Any) -> int:
                # UDP datagrams do not require connect(); block them too.
                raise _blocked()

        def guarded_create_connection(
            address: Any,
            timeout: float | None = None,
            source_address: Any | None = None,
        ) -> socket.socket:
            raise _blocked()

        setattr(socket, "socket", GuardedSocket)  # noqa: B010
        setattr(socket, "create_connection", guarded_create_connection)  # noqa: B010
        self._installed = True

    def restore(self) -> None:
        """Restore the original socket implementations (idempotent)."""
        if not self._installed:
            return

        if self._original_socket_cls is not None:
            setattr(socket, "socket", self._original_socket_cls)  # noqa: B010
        if self._original_create_connection is not None:
            setattr(socket, "create_connection", self._original_create_connection)  # noqa: B010
        self._installed = False
89
+
90
+
91
+ _GUARD = NetworkGuard()
92
+
93
+
94
def enforce_network_policy(allow: bool) -> None:
    """
    Apply the global network policy.

    Args:
        allow: True removes the guard; False blocks outbound connections.
    """
    if not allow:
        _GUARD.install()
        return
    _GUARD.restore()
105
+
106
+
107
def network_policy_allows() -> bool:
    """Report whether outbound network connections are currently permitted."""
    blocked = _GUARD.installed
    return not blocked
110
+
111
+
112
def enforce_default_security() -> None:
    """
    Enforce the default runtime security posture.

    Outbound network access stays blocked unless the environment variable
    INVARLOCK_ALLOW_NETWORK is set to a truthy value (1/true/yes/on).
    """
    truthy_values = {"1", "true", "yes", "on"}
    raw_value = os.environ.get("INVARLOCK_ALLOW_NETWORK", "")
    enforce_network_policy(raw_value.strip().lower() in truthy_values)
121
+
122
+
123
@contextlib.contextmanager
def temporarily_allow_network() -> Iterator[None]:
    """
    Allow network access for the duration of the context block.

    The previous policy (blocked or not) is reinstated on exit.
    """
    previously_blocked = _GUARD.installed
    if previously_blocked:
        _GUARD.restore()
    try:
        yield
    finally:
        # Re-arm the guard only if it was active when we entered.
        if previously_blocked:
            _GUARD.install()
138
+
139
+
140
@contextlib.contextmanager
def secure_tempdir(
    prefix: str = "invarlock-", base_dir: str | os.PathLike[str] | None = None
) -> Iterator[Path]:
    """
    Create a temporary directory with 0o700 permissions that is removed on exit.

    Args:
        prefix: Directory name prefix.
        base_dir: Optional base directory to create the tempdir under.

    Yields:
        Path to the secure temporary directory.
    """
    path = Path(tempfile.mkdtemp(prefix=prefix, dir=base_dir))
    try:
        # mkdtemp creates the directory readable only by the owner, but
        # enforce 0o700 explicitly in case platform behavior differs.
        # chmod lives inside the try so the directory is still removed if
        # tightening permissions fails.
        os.chmod(path, 0o700)
        yield path
    finally:
        shutil.rmtree(path, ignore_errors=True)
160
+
161
+
162
def is_secure_path(path: Path) -> bool:
    """
    Check whether a path exists with exactly 0o700 permission bits.

    Args:
        path: Path to validate.

    Returns:
        True if the path exists and has 0o700 permissions, False otherwise.
    """
    try:
        permissions = stat.S_IMODE(path.stat().st_mode)
    except FileNotFoundError:
        return False
    return permissions == 0o700
@@ -0,0 +1,323 @@
1
+ """
2
+ InvarLock Sparsity Utilities
3
+ ========================
4
+
5
+ Utilities for working with sparse models and sparsity patterns.
6
+ Helper functions for analyzing and manipulating model sparsity.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import warnings
12
+ from typing import Any
13
+
14
+ try:
15
+ import torch
16
+ import torch.nn as nn
17
+
18
+ TORCH_AVAILABLE = True
19
+ except ImportError:
20
+ TORCH_AVAILABLE = False
21
+
22
__all__ = [
    "calculate_sparsity",
    "get_zero_mask",
    "apply_mask",
    "count_parameters",
    "get_sparsity_stats",
    "create_structured_mask",
    "validate_sparsity_target",
    # These two are public, documented helpers defined in this module but
    # were previously missing from the declared API.
    "get_magnitude_mask",
    "analyze_sparsity_impact",
]
31
+
32
+
33
def calculate_sparsity(tensor) -> float:
    """
    Calculate the fraction of exactly-zero elements in a tensor.

    Args:
        tensor: PyTorch tensor

    Returns:
        Sparsity ratio in [0, 1]; 0.0 when torch is unavailable, the input
        is not a tensor, or the tensor is empty.
    """
    if not TORCH_AVAILABLE or not isinstance(tensor, torch.Tensor):
        return 0.0

    count = tensor.numel()
    if count == 0:
        return 0.0
    return (tensor == 0).sum().item() / count
55
+
56
+
57
def get_zero_mask(tensor, threshold: float = 1e-8):
    """
    Build a boolean mask marking elements whose magnitude is at or below
    a threshold.

    Args:
        tensor: PyTorch tensor
        threshold: Magnitude at or below which a value counts as zero

    Returns:
        Boolean mask (True = zero/near-zero), or None when torch is
        unavailable or the input is not a tensor.
    """
    if not TORCH_AVAILABLE or not isinstance(tensor, torch.Tensor):
        return None
    return tensor.abs() <= threshold
75
+
76
+
77
def apply_mask(tensor, mask, fill_value: float = 0.0):
    """
    Return a copy of a tensor with masked positions set to a fill value.

    Args:
        tensor: PyTorch tensor to copy and modify
        mask: Boolean mask (True = replace with fill_value)
        fill_value: Value written at masked positions

    Returns:
        New tensor with the mask applied; the input is returned unchanged
        when torch is unavailable or either argument is not a tensor.
    """
    if not TORCH_AVAILABLE:
        return tensor
    if not (isinstance(tensor, torch.Tensor) and isinstance(mask, torch.Tensor)):
        return tensor

    out = tensor.clone()
    out[mask] = fill_value
    return out
98
+
99
+
100
def count_parameters(module, only_trainable: bool = True) -> dict[str, int]:
    """
    Count parameters in a module, split by trainability.

    Args:
        module: PyTorch module
        only_trainable: Accepted for backward compatibility; the returned
            dict always contains all three counts regardless of this flag.

    Returns:
        Dict with 'total', 'trainable' and 'non_trainable' counts.
    """
    if not TORCH_AVAILABLE or not isinstance(module, nn.Module):
        return {"total": 0, "trainable": 0, "non_trainable": 0}

    trainable = 0
    frozen = 0
    for param in module.parameters():
        if param.requires_grad:
            trainable += param.numel()
        else:
            frozen += param.numel()

    return {
        "total": trainable + frozen,
        "trainable": trainable,
        "non_trainable": frozen,
    }
132
+
133
+
134
def get_sparsity_stats(module) -> dict[str, Any]:
    """
    Collect per-layer and overall sparsity statistics for a module.

    Args:
        module: PyTorch module

    Returns:
        Dict with overall sparsity, per-parameter sparsities, parameter
        counts, and sparse/dense layer tallies; empty dict when torch is
        unavailable or the input is not an nn.Module.
    """
    if not TORCH_AVAILABLE or not isinstance(module, nn.Module):
        return {}

    stats: dict[str, Any] = {
        "overall_sparsity": 0.0,
        "layer_sparsities": {},
        "parameter_counts": count_parameters(module),
        "sparse_layers": 0,
        "dense_layers": 0,
    }

    grand_total = 0
    grand_zeros = 0

    for name, param in module.named_parameters():
        if param.numel() == 0:
            continue

        sparsity = calculate_sparsity(param)
        stats["layer_sparsities"][name] = sparsity

        grand_total += param.numel()
        grand_zeros += (param == 0).sum().item()

        # A layer counts as "sparse" above the 10% zero threshold.
        bucket = "sparse_layers" if sparsity > 0.1 else "dense_layers"
        stats[bucket] += 1

    if grand_total > 0:
        stats["overall_sparsity"] = grand_zeros / grand_total

    return stats
183
+
184
+
185
def create_structured_mask(
    shape: tuple[int, ...], pattern: str = "block", block_size: int = 4
) -> torch.Tensor | None:
    """
    Build a structured (block / column / row) boolean sparsity mask.

    Args:
        shape: Shape of the mask tensor
        pattern: One of 'block', 'column' or 'row'
        block_size: Block edge length for the 'block' pattern

    Returns:
        Boolean mask (True = masked), or None when torch is unavailable.
        Unsupported pattern/shape combinations warn and yield an all-False
        mask.
    """
    if not TORCH_AVAILABLE:
        return None

    mask = torch.zeros(shape, dtype=torch.bool)
    is_2d = len(shape) >= 2

    if is_2d and pattern == "block":
        rows, cols = shape[0], shape[1]
        for r0 in range(0, rows, block_size):
            for c0 in range(0, cols, block_size):
                # Checkerboard of blocks: mask every other block.
                if ((r0 // block_size) + (c0 // block_size)) % 2 == 0:
                    r1 = min(r0 + block_size, rows)
                    c1 = min(c0 + block_size, cols)
                    mask[r0:r1, c0:c1] = True
    elif is_2d and pattern == "column":
        # Zero out every other column.
        mask[:, ::2] = True
    elif is_2d and pattern == "row":
        # Zero out every other row.
        mask[::2, :] = True
    else:
        warnings.warn(
            f"Unsupported sparsity pattern '{pattern}' for shape {shape}", stacklevel=2
        )

    return mask
228
+
229
+
230
+ def validate_sparsity_target(target_sparsity: Any) -> bool:
231
+ """
232
+ Validate sparsity target value.
233
+
234
+ Args:
235
+ target_sparsity: Target sparsity ratio (0.0 to 1.0)
236
+
237
+ Returns:
238
+ True if valid
239
+ """
240
+ if not isinstance(target_sparsity, int | float):
241
+ return False
242
+ return 0.0 <= target_sparsity <= 1.0
243
+
244
+
245
def get_magnitude_mask(tensor, sparsity_ratio: float) -> torch.Tensor | None:
    """
    Create a magnitude-based sparsity mask marking the smallest weights.

    Args:
        tensor: PyTorch tensor
        sparsity_ratio: Target fraction of elements to zero (0.0 to 1.0)

    Returns:
        Boolean mask (True = weight to zero), or None when torch is
        unavailable or the input is not a tensor.

    Raises:
        ValueError: If sparsity_ratio is not a number in [0.0, 1.0].
    """
    if not TORCH_AVAILABLE or not isinstance(tensor, torch.Tensor):
        return None

    if not validate_sparsity_target(sparsity_ratio):
        raise ValueError(f"Invalid sparsity ratio: {sparsity_ratio}")

    # reshape(-1), unlike view(-1), also handles non-contiguous tensors
    # (e.g. transposed weight matrices), which would make view() raise.
    flat_tensor = tensor.reshape(-1)

    num_to_zero = int(sparsity_ratio * flat_tensor.numel())
    if num_to_zero == 0:
        return torch.zeros_like(tensor, dtype=torch.bool)

    # The k-th smallest magnitude is the cut-off; everything at or below it
    # is masked (ties may push realised sparsity slightly above the target).
    threshold_value = torch.kthvalue(torch.abs(flat_tensor), num_to_zero).values

    return torch.abs(tensor) <= threshold_value
278
+
279
+
280
def analyze_sparsity_impact(original_tensor, edited_tensor) -> dict[str, Any]:
    """
    Summarize how an edit changed a tensor's sparsity and magnitude.

    Args:
        original_tensor: Tensor before the edit
        edited_tensor: Tensor after applying sparsity

    Returns:
        Dict with before/after sparsity, the sparsity delta, absolute and
        relative magnitude change, and a compression ratio; empty dict
        when torch is unavailable or either input is not a tensor.
    """
    if not TORCH_AVAILABLE:
        return {}
    if not isinstance(original_tensor, torch.Tensor):
        return {}
    if not isinstance(edited_tensor, torch.Tensor):
        return {}

    before = calculate_sparsity(original_tensor)
    after = calculate_sparsity(edited_tensor)

    delta_norm = torch.norm(edited_tensor - original_tensor).item()
    base_norm = torch.norm(original_tensor)
    rel_change = delta_norm / base_norm.item() if base_norm > 0 else 0.0

    return {
        "original_sparsity": before,
        "final_sparsity": after,
        "sparsity_increase": after - before,
        "magnitude_change": delta_norm,
        "relative_change": rel_change,
        "compression_ratio": after / before if before > 0 else float("inf"),
    }
@@ -0,0 +1,150 @@
1
+ """
2
+ InvarLock Utilities
3
+ ===============
4
+
5
+ Common utility functions used across InvarLock modules.
6
+
7
+ This package also exposes submodules such as `invarlock.utils.digest` for
8
+ hashing and provenance utilities.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from typing import Any
14
+
15
+ import psutil
16
+
17
+ try: # Torch is optional; utils can be imported without it.
18
+ import torch # type: ignore[import]
19
+ except Exception: # pragma: no cover - exercised when torch is missing
20
+ torch = None # type: ignore[assignment]
21
+
22
+
23
def extract_input_ids(
    batch: Any, device: str | None = None, strict: bool = False
) -> torch.Tensor:
    """
    Extract an input_ids tensor from common batch formats.

    Args:
        batch: A tensor, a dict (keys 'input_ids'/'inputs' preferred, else
            the first tensor value), an object with an `input_ids`
            attribute, or raw data convertible via torch.tensor()
        device: Optional target device for the returned tensor
        strict: Raise on unrecognized formats instead of best-effort fallback

    Returns:
        The extracted input_ids tensor (moved to *device* when given).

    Raises:
        ValueError: In strict mode for unsupported formats, or when a dict
            contains no tensor values at all.
    """
    if isinstance(batch, torch.Tensor):
        ids = batch
    elif isinstance(batch, dict):
        if "input_ids" in batch:
            ids = batch["input_ids"]
        elif "inputs" in batch:
            ids = batch["inputs"]
        else:
            if strict:
                raise ValueError(
                    f"Dict batch missing 'input_ids' or 'inputs' keys: {list(batch.keys())}"
                )
            # Best effort: take the first tensor-valued entry.
            candidates = [v for v in batch.values() if isinstance(v, torch.Tensor)]
            if not candidates:
                raise ValueError("No tensor found in batch dict")
            ids = candidates[0]
    elif hasattr(batch, "input_ids"):
        ids = batch.input_ids
    else:
        if strict:
            raise ValueError(f"Unsupported batch format: {type(batch)}")
        # Best effort: direct conversion of raw data.
        ids = torch.tensor(batch)

    return ids.to(device) if device is not None else ids
69
+
70
+
71
def get_model_device(model: torch.nn.Module) -> torch.device:
    """Return the device of the model's first parameter."""
    first_param = next(iter(model.parameters()))
    return first_param.device
74
+
75
+
76
def ensure_tensor(data: Any, device: torch.device | None = None) -> torch.Tensor:
    """Coerce *data* to a tensor, optionally moving it to *device*."""
    tensor = data if isinstance(data, torch.Tensor) else torch.tensor(data)
    return tensor.to(device) if device is not None else tensor
85
+
86
+
87
def safe_divide(numerator: float, denominator: float, default: float = 0.0) -> float:
    """Divide, returning *default* when the denominator is (near) zero."""
    # Magnitudes below 1e-12 are treated as zero to avoid huge/inf results.
    return default if abs(denominator) < 1e-12 else numerator / denominator
92
+
93
+
94
def dict_to_device(
    data: dict[str, torch.Tensor], device: torch.device
) -> dict[str, torch.Tensor]:
    """Return a copy of *data* with every tensor value moved to *device*."""
    moved: dict[str, torch.Tensor] = {}
    for key, value in data.items():
        # Non-tensor values pass through untouched.
        moved[key] = value.to(device) if isinstance(value, torch.Tensor) else value
    return moved
102
+
103
+
104
def format_number(num: float, precision: int = 3) -> str:
    """Format a number for display: scientific below 1e-3, otherwise fixed."""
    magnitude = abs(num)
    if magnitude < 1e-3:
        return f"{num:.2e}"
    if magnitude < 1:
        # One extra digit of precision for sub-unit values.
        return f"{num:.{precision + 1}f}"
    return f"{num:.{precision}f}"
112
+
113
+
114
def get_memory_usage() -> dict[str, float]:
    """
    Return current process memory usage in MB.

    Always reports RSS and VMS; adds CUDA allocated/reserved figures when
    torch reports an available CUDA device.
    """
    import gc

    # Reclaim garbage first so the numbers reflect live objects.
    gc.collect()

    mem = psutil.Process().memory_info()
    usage = {
        "rss_mb": mem.rss / 1024 / 1024,  # Resident Set Size
        "vms_mb": mem.vms / 1024 / 1024,  # Virtual Memory Size
    }

    try:
        if torch is not None and hasattr(torch, "cuda") and torch.cuda.is_available():
            usage["cuda_allocated_mb"] = torch.cuda.memory_allocated() / 1024 / 1024
            usage["cuda_reserved_mb"] = torch.cuda.memory_reserved() / 1024 / 1024
    except Exception:
        # Best effort: fall back to CPU-only stats when CUDA queries fail.
        pass

    return usage
140
+
141
+
142
+ __all__ = [
143
+ "extract_input_ids",
144
+ "get_model_device",
145
+ "ensure_tensor",
146
+ "safe_divide",
147
+ "dict_to_device",
148
+ "format_number",
149
+ "get_memory_usage",
150
+ ]
@@ -0,0 +1,45 @@
1
+ from __future__ import annotations
2
+
3
+ import hashlib
4
+ import json
5
+ from typing import Any, Protocol
6
+
7
+ _ENC = "utf-8"
8
+
9
+
10
+ class _HashLike(Protocol):
11
+ def update(self, data: bytes, /) -> None: ...
12
+ def hexdigest(self, /) -> str: ...
13
+
14
+
15
+ def _h() -> _HashLike:
16
+ return hashlib.blake2s(digest_size=32)
17
+
18
+
19
+ def hash_bytes(b: bytes, *, salt: bytes | None = None) -> str:
20
+ h = _h()
21
+ if salt:
22
+ h.update(salt)
23
+ h.update(b)
24
+ return h.hexdigest()
25
+
26
+
27
+ def hash_json(obj: Any, *, salt: str | None = None) -> str:
28
+ s = json.dumps(obj, sort_keys=True, separators=(",", ":"), ensure_ascii=False)
29
+ return hash_bytes(s.encode(_ENC), salt=salt.encode(_ENC) if salt else None)
30
+
31
+
32
+ def hash_int_array(arr, *, salt: str | None = None, byteorder: str = "little") -> str:
33
+ # Accept numpy arrays to avoid extra copies
34
+ try:
35
+ import numpy as _np
36
+ except Exception: # pragma: no cover - numpy always present in tests
37
+ # Fallback: best-effort conversion
38
+ b = bytes(int(x) & 0xFFFFFFFF for x in arr)
39
+ return hash_bytes(b, salt=salt.encode(_ENC) if salt else None)
40
+
41
+ a = _np.asarray(arr, dtype=_np.int32, order="C")
42
+ return hash_bytes(a.tobytes(order="C"), salt=salt.encode(_ENC) if salt else None)
43
+
44
+
45
+ __all__ = ["hash_bytes", "hash_json", "hash_int_array"]