space-ml-sim 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36) hide show
  1. space_ml_sim/__init__.py +3 -0
  2. space_ml_sim/compute/__init__.py +32 -0
  3. space_ml_sim/compute/checkpoint.py +78 -0
  4. space_ml_sim/compute/fault_injector.py +277 -0
  5. space_ml_sim/compute/inference_node.py +88 -0
  6. space_ml_sim/compute/onnx_adapter.py +246 -0
  7. space_ml_sim/compute/quantization.py +207 -0
  8. space_ml_sim/compute/scheduler.py +81 -0
  9. space_ml_sim/compute/tmr.py +254 -0
  10. space_ml_sim/compute/transformer_fault.py +340 -0
  11. space_ml_sim/core/__init__.py +30 -0
  12. space_ml_sim/core/clock.py +50 -0
  13. space_ml_sim/core/constellation.py +252 -0
  14. space_ml_sim/core/orbit.py +315 -0
  15. space_ml_sim/core/satellite.py +125 -0
  16. space_ml_sim/core/tle.py +266 -0
  17. space_ml_sim/environment/__init__.py +21 -0
  18. space_ml_sim/environment/comms.py +39 -0
  19. space_ml_sim/environment/power.py +30 -0
  20. space_ml_sim/environment/radiation.py +150 -0
  21. space_ml_sim/environment/thermal.py +39 -0
  22. space_ml_sim/environment/timeline.py +281 -0
  23. space_ml_sim/metrics/__init__.py +6 -0
  24. space_ml_sim/metrics/performance.py +43 -0
  25. space_ml_sim/metrics/reliability.py +48 -0
  26. space_ml_sim/models/__init__.py +27 -0
  27. space_ml_sim/models/chip_profiles.py +123 -0
  28. space_ml_sim/models/rad_profiles.py +13 -0
  29. space_ml_sim/py.typed +0 -0
  30. space_ml_sim/viz/__init__.py +11 -0
  31. space_ml_sim/viz/heatmap.py +147 -0
  32. space_ml_sim/viz/plots.py +166 -0
  33. space_ml_sim-0.3.0.dist-info/METADATA +242 -0
  34. space_ml_sim-0.3.0.dist-info/RECORD +36 -0
  35. space_ml_sim-0.3.0.dist-info/WHEEL +4 -0
  36. space_ml_sim-0.3.0.dist-info/licenses/LICENSE +661 -0
@@ -0,0 +1,3 @@
1
"""space-ml-sim: Simulate AI inference on orbital satellite constellations under realistic space radiation."""

# NOTE(review): presumably mirrors the version in the package build metadata — keep in sync.
__version__ = "0.3.0"
@@ -0,0 +1,32 @@
1
"""Compute module: fault injection, TMR, checkpointing, and scheduling."""

from space_ml_sim.compute.fault_injector import FaultInjector, FaultReport
from space_ml_sim.compute.transformer_fault import TransformerFaultInjector
from space_ml_sim.compute.tmr import TMRWrapper
from space_ml_sim.compute.checkpoint import CheckpointManager
from space_ml_sim.compute.scheduler import InferenceScheduler
from space_ml_sim.compute.quantization import (
    quantize_model,
    compare_quantization_resilience,
    plot_quantization_comparison,
)

# Public API of the compute subpackage; extended below when the optional
# ONNX adapter is importable.
__all__ = [
    "FaultInjector",
    "FaultReport",
    "TransformerFaultInjector",
    "TMRWrapper",
    "CheckpointManager",
    "InferenceScheduler",
    "quantize_model",
    "compare_quantization_resilience",
    "plot_quantization_comparison",
]

# Optional ONNX support — only available when 'onnx' and 'onnxruntime' are installed.
try:
    from space_ml_sim.compute.onnx_adapter import OnnxModel, load_onnx

    __all__ += ["OnnxModel", "load_onnx"]
except ImportError:
    pass  # onnx extras not installed; the core API above still works
@@ -0,0 +1,78 @@
1
+ """Model checkpointing for fault recovery.
2
+
3
+ Periodically saves model state so it can be restored after
4
+ radiation-induced corruption is detected.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import copy
10
+ from collections import deque
11
+ from typing import Any
12
+
13
+ import torch
14
+
15
+
16
class CheckpointManager:
    """Sliding-window store of model checkpoints for radiation fault recovery.

    Retains up to ``max_checkpoints`` recent snapshots of a model's state
    dict so a model corrupted by radiation-induced faults can be rolled
    back to a known-good state.
    """

    def __init__(self, max_checkpoints: int = 3) -> None:
        """Initialize checkpoint manager.

        Args:
            max_checkpoints: Maximum number of checkpoints to retain;
                older snapshots are evicted automatically.
        """
        self._max = max_checkpoints
        # deque(maxlen=...) drops the oldest snapshot on overflow.
        self._checkpoints: deque[dict[str, Any]] = deque(maxlen=max_checkpoints)

    def save(self, model: torch.nn.Module, metadata: dict[str, Any] | None = None) -> int:
        """Snapshot the model's current parameters.

        Args:
            model: Model to checkpoint.
            metadata: Optional metadata (e.g., step count, accuracy).

        Returns:
            Index of the saved checkpoint.
        """
        # Deep-copy the state dict: the live parameters keep changing
        # after save, and the snapshot must stay frozen.
        snapshot = {
            "state_dict": copy.deepcopy(model.state_dict()),
            "metadata": metadata or {},
        }
        self._checkpoints.append(snapshot)
        return len(self._checkpoints) - 1

    def restore(self, model: torch.nn.Module, index: int = -1) -> dict[str, Any]:
        """Roll the model back to a stored checkpoint.

        Args:
            model: Model to restore into.
            index: Checkpoint index (-1 for most recent).

        Returns:
            Metadata associated with the restored checkpoint.

        Raises:
            IndexError: If no checkpoints are available.
        """
        if len(self._checkpoints) == 0:
            raise IndexError("No checkpoints available")

        chosen = self._checkpoints[index]
        # load_state_dict copies tensor data into the model's own buffers
        # without mutating the source dict, so no deepcopy is needed here.
        model.load_state_dict(chosen["state_dict"])
        return chosen["metadata"]

    @property
    def count(self) -> int:
        """Number of stored checkpoints."""
        return len(self._checkpoints)

    def clear(self) -> None:
        """Remove all checkpoints."""
        self._checkpoints.clear()
@@ -0,0 +1,277 @@
1
+ """ML-aware radiation fault injection using PyTorch hooks.
2
+
3
+ Injects radiation-modeled faults into neural network inference:
4
+ 1. Weight SEU: Flip random bits in weight tensors (persistent memory fault)
5
+ 2. Activation SET: Flip random bits in activations during forward pass (transient)
6
+ 3. Stuck-at: Zero out weights in TID-degraded regions (permanent)
7
+
8
+ Key insight: MSB flips in the IEEE 754 exponent field cause catastrophic errors,
9
+ while LSB flips in the mantissa are often benign. Real radiation is uniform
10
+ across bit positions.
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import copy
16
+ from dataclasses import dataclass
17
+ from typing import TYPE_CHECKING
18
+
19
+ import numpy as np
20
+ import torch
21
+
22
+ if TYPE_CHECKING:
23
+ import pandas as pd
24
+
25
+ from space_ml_sim.environment.radiation import RadiationEnvironment
26
+ from space_ml_sim.models.chip_profiles import ChipProfile
27
+
28
+
29
@dataclass(frozen=True)
class FaultReport:
    """Summary of faults injected during a simulation run.

    Immutable record returned by the injection APIs in this module.
    """

    # Total number of bit flips injected.
    total_faults_injected: int = 0
    # Flips applied to persistent weight memory (SEU).
    weight_faults: int = 0
    # Transient flips applied to activations during forward passes (SET).
    activation_faults: int = 0
    # Names of parameters that received at least one flip.
    layers_affected: tuple[str, ...] = ()
    # Bit position of each injected flip (0-31 for float32).
    bit_positions_flipped: tuple[int, ...] = ()
38
+
39
+
40
class FaultInjector:
    """Inject radiation-modeled faults into PyTorch model inference.

    Works independently of the orbital mechanics module — can be used
    standalone for fault tolerance research on any PyTorch model.
    """

    def __init__(
        self,
        rad_env: RadiationEnvironment,
        chip_profile: ChipProfile,
        seed: int | None = None,
    ) -> None:
        """Initialize the injector.

        Args:
            rad_env: Radiation environment providing the base SEU rate.
            chip_profile: Hardware profile of the target chip.
            seed: Optional seed for the NumPy RNG that samples fault counts.
                NOTE: bit/element positions in flip_random_bits draw from
                torch's global RNG, so full reproducibility also requires
                torch.manual_seed(...).
        """
        self.rad_env = rad_env
        self.chip = chip_profile
        # Fixed: the handle class is RemovableHandle; RemovableHook does
        # not exist in torch.utils.hooks (annotation was lazily evaluated,
        # so it never raised at runtime, but it broke type checking).
        self._hooks: list[torch.utils.hooks.RemovableHandle] = []
        self._rng = np.random.default_rng(seed)

    @staticmethod
    def flip_random_bits(tensor: torch.Tensor, num_flips: int) -> list[int]:
        """Flip random bits in a float tensor's binary representation.

        Supports float32, float16, and bfloat16; any other dtype is
        treated as 32-bit. The tensor is modified in place.

        Args:
            tensor: Float tensor (modified in-place).
            num_flips: Number of random bit flips to inject.

        Returns:
            List of bit positions that were flipped (0 to bits-1 for
            the tensor's dtype).
        """
        if num_flips == 0:
            return []

        # view() fails on non-contiguous tensors (common with transposed
        # weights in transformers), so compact the storage first.
        if not tensor.is_contiguous():
            tensor.data = tensor.contiguous()
        flat = tensor.view(-1)
        n = flat.numel()
        if n == 0:
            return []

        # Reinterpreting the storage as a same-width integer dtype lets
        # us XOR directly on the IEEE 754 bit pattern.
        dtype_map = {
            torch.float32: (torch.int32, 32),
            torch.float16: (torch.int16, 16),
            torch.bfloat16: (torch.int16, 16),
        }
        int_dtype, num_bits = dtype_map.get(flat.dtype, (torch.int32, 32))

        indices = torch.randint(0, n, (num_flips,))
        bit_positions = torch.randint(0, num_bits, (num_flips,))
        masks = (1 << bit_positions).to(int_dtype)

        int_view = flat.view(int_dtype)
        # Sequential XOR so duplicate (index, bit) draws toggle the bit
        # twice instead of being collapsed into one write.
        for i in range(num_flips):
            int_view[indices[i]] ^= masks[i]

        return bit_positions.tolist()

    def inject_weight_faults(
        self,
        model: torch.nn.Module,
        num_faults: int | None = None,
        inference_time_seconds: float = 0.001,
    ) -> FaultReport:
        """Inject SEU bit flips into model weights.

        If num_faults is None, the count is sampled from a Poisson
        distribution based on the radiation environment, total weight
        bits, and inference time.

        Args:
            model: PyTorch model (weights modified in-place).
            num_faults: Exact number of faults, or None for radiation-sampled.
            inference_time_seconds: Duration used for Poisson rate calculation.

        Returns:
            FaultReport summarizing injected faults.
        """
        if num_faults is None:
            # Fixed: derive bit count from the actual element size instead
            # of assuming 32-bit weights (matters for fp16/bf16 models;
            # identical result for float32).
            total_weight_bits = sum(
                p.numel() * p.element_size() * 8 for p in model.parameters()
            )
            expected = self.rad_env.base_seu_rate * total_weight_bits * inference_time_seconds
            num_faults = int(self._rng.poisson(expected))

        if num_faults == 0:
            return FaultReport()

        # Distribute faults across layers proportional to parameter count.
        params = [(name, p) for name, p in model.named_parameters() if p.requires_grad]
        sizes = [p.numel() for _, p in params]
        total = sum(sizes)
        if total == 0:
            return FaultReport()

        all_layers: list[str] = []
        all_bits: list[int] = []
        total_injected = 0

        for name, param in params:
            # At least one fault per layer, capped by the remaining budget.
            layer_faults = max(1, int(num_faults * param.numel() / total))
            layer_faults = min(layer_faults, num_faults - total_injected)
            if layer_faults <= 0:
                continue

            with torch.no_grad():
                bits = self.flip_random_bits(param.data, layer_faults)

            # Fixed: count what was actually flipped — a zero-size
            # parameter flips nothing and must not inflate the tally
            # or appear in layers_affected.
            if not bits:
                continue

            total_injected += len(bits)
            all_layers.append(name)
            all_bits.extend(bits)

            if total_injected >= num_faults:
                break

        return FaultReport(
            total_faults_injected=total_injected,
            weight_faults=total_injected,
            layers_affected=tuple(all_layers),
            bit_positions_flipped=tuple(all_bits),
        )

    def register_activation_hooks(
        self,
        model: torch.nn.Module,
        fault_probability: float = 0.001,
    ) -> None:
        """Register forward hooks for transient activation fault injection.

        Each leaf module gets a hook that flips bits in-place in its
        tensor output; the per-pass flip count is binomially sampled.
        Any previously registered hooks are removed first.

        Args:
            model: PyTorch model to hook.
            fault_probability: Probability of a bit flip per tensor element
                per forward pass.
        """
        self.remove_hooks()
        rng = self._rng

        def make_hook(layer_name: str):
            # Closure factory so each hook captures its own layer name.
            def hook(module, input, output):
                if isinstance(output, torch.Tensor) and output.numel() > 0:
                    num_faults = int(rng.binomial(output.numel(), fault_probability))
                    if num_faults > 0:
                        FaultInjector.flip_random_bits(output, num_faults)
                return output

            return hook

        for name, module in model.named_modules():
            # Hook leaf modules only; containers would double-inject.
            if len(list(module.children())) == 0:
                h = module.register_forward_hook(make_hook(name))
                self._hooks.append(h)

    def remove_hooks(self) -> None:
        """Remove all registered activation hooks."""
        for h in self._hooks:
            h.remove()
        self._hooks.clear()

    def sweep(
        self,
        model: torch.nn.Module,
        dataloader: torch.utils.data.DataLoader,
        fault_counts: list[int],
        num_trials: int = 5,
    ) -> "pd.DataFrame":
        """Run fault injection sweep across multiple fault counts and trials.

        For each fault count, runs num_trials with independent fault
        injections and measures accuracy degradation.

        Args:
            model: Base PyTorch model (left untouched; a working copy is used).
            dataloader: Evaluation data loader.
            fault_counts: List of fault counts to sweep.
            num_trials: Number of independent trials per fault count.

        Returns:
            DataFrame with columns: fault_count, trial, accuracy,
            top5_accuracy, critical_failure, faults_injected, layers_affected.
        """
        import pandas as pd

        results: list[dict] = []

        # One working model plus one pristine state-dict snapshot.
        # Fixed: load_state_dict copies tensor data into the model without
        # mutating its source, so the old per-trial deepcopy of the clean
        # state was pure overhead.
        test_model = copy.deepcopy(model)
        clean_state = copy.deepcopy(model.state_dict())

        for fc in fault_counts:
            for trial in range(num_trials):
                test_model.load_state_dict(clean_state)
                test_model.eval()

                report = self.inject_weight_faults(test_model, num_faults=fc)

                acc, top5 = _evaluate_model(test_model, dataloader)
                results.append(
                    {
                        "fault_count": fc,
                        "trial": trial,
                        "accuracy": acc,
                        "top5_accuracy": top5,
                        "critical_failure": acc < 0.1,
                        "faults_injected": report.total_faults_injected,
                        "layers_affected": len(report.layers_affected),
                    }
                )

        return pd.DataFrame(results)
248
+
249
+
250
+ def _evaluate_model(
251
+ model: torch.nn.Module,
252
+ dataloader: torch.utils.data.DataLoader,
253
+ ) -> tuple[float, float]:
254
+ """Evaluate a model on a dataloader, returning (top1_accuracy, top5_accuracy)."""
255
+ correct = 0
256
+ correct_top5 = 0
257
+ total = 0
258
+
259
+ with torch.no_grad():
260
+ for images, labels in dataloader:
261
+ outputs = model(images)
262
+ _, predicted = outputs.max(1)
263
+ correct += predicted.eq(labels).sum().item()
264
+
265
+ if outputs.size(1) >= 5:
266
+ _, top5_pred = outputs.topk(5, dim=1)
267
+ for i in range(len(labels)):
268
+ if labels[i] in top5_pred[i]:
269
+ correct_top5 += 1
270
+ else:
271
+ correct_top5 += predicted.eq(labels).sum().item()
272
+
273
+ total += labels.size(0)
274
+
275
+ acc = correct / total if total > 0 else 0.0
276
+ top5 = correct_top5 / total if total > 0 else 0.0
277
+ return acc, top5
@@ -0,0 +1,88 @@
1
+ """Inference node abstraction for running ML models on a satellite."""
2
+
3
from __future__ import annotations

import time
from enum import Enum

import torch
from pydantic import BaseModel, Field

from space_ml_sim.models.chip_profiles import ChipProfile
11
+
12
+
13
class NodeStatus(str, Enum):
    """Inference node operational status.

    Inherits from str so values serialize cleanly (e.g., into pydantic
    models and JSON).
    """

    IDLE = "idle"        # ready to accept work
    RUNNING = "running"  # forward pass in progress
    FAULTED = "faulted"  # set when run_inference raises
19
+
20
+
21
class InferenceResult(BaseModel):
    """Result of a single inference run."""

    # Argmax class index per input in the batch.
    predictions: list[int] = Field(default_factory=list)
    # Latency of the inference call, in milliseconds.
    latency_ms: float = 0.0
    # Faults that occurred during this run — NOTE(review): no producer in
    # this module sets it; presumably filled by a fault injector — confirm.
    faults_during_inference: int = 0
    # Node status at the time the result was produced.
    status: NodeStatus = NodeStatus.IDLE
28
+
29
+
30
class InferenceNode:
    """Wraps a PyTorch model as a compute node on a satellite.

    Tracks inference count, fault history, and thermal/power constraints.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        chip_profile: ChipProfile,
        device: str = "cpu",
    ) -> None:
        """Initialize the node.

        Args:
            model: Model to serve; moved to `device` and set to eval mode.
            chip_profile: Hardware profile supplying power/thermal limits.
            device: Torch device string for inference.
        """
        self.model = model.to(device).eval()
        self.chip_profile = chip_profile
        self.device = device
        self.status = NodeStatus.IDLE
        # Count of successfully completed inference calls.
        self.inference_count = 0
        # NOTE(review): never incremented in this class; presumably
        # maintained by an external fault injector — confirm.
        self.total_faults = 0

    def run_inference(self, inputs: torch.Tensor) -> InferenceResult:
        """Run inference on a batch of inputs.

        Args:
            inputs: Input tensor batch.

        Returns:
            InferenceResult with predictions, measured latency, and metadata.

        Raises:
            Exception: Re-raises any model failure after marking the
                node FAULTED.
        """
        self.status = NodeStatus.RUNNING
        # Fixed: latency_ms existed on InferenceResult but was never
        # populated; time the forward pass so the field is meaningful.
        start = time.perf_counter()
        try:
            with torch.no_grad():
                outputs = self.model(inputs.to(self.device))
                predictions = outputs.argmax(dim=1).tolist()
        except Exception:
            # Leave the node marked FAULTED so callers can react.
            self.status = NodeStatus.FAULTED
            raise
        latency_ms = (time.perf_counter() - start) * 1000.0

        self.inference_count += 1
        self.status = NodeStatus.IDLE

        return InferenceResult(
            predictions=predictions,
            latency_ms=latency_ms,
            status=NodeStatus.IDLE,
        )

    def can_run(self, power_available_w: float, temperature_c: float) -> bool:
        """Check if the node has enough power and is within thermal limits.

        Args:
            power_available_w: Available power in watts.
            temperature_c: Current temperature in Celsius.

        Returns:
            True if the node can safely run inference.
        """
        return (
            power_available_w >= self.chip_profile.tdp_watts
            and temperature_c <= self.chip_profile.max_temp_c
        )