space-ml-sim 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- space_ml_sim/__init__.py +3 -0
- space_ml_sim/compute/__init__.py +32 -0
- space_ml_sim/compute/checkpoint.py +78 -0
- space_ml_sim/compute/fault_injector.py +277 -0
- space_ml_sim/compute/inference_node.py +88 -0
- space_ml_sim/compute/onnx_adapter.py +246 -0
- space_ml_sim/compute/quantization.py +207 -0
- space_ml_sim/compute/scheduler.py +81 -0
- space_ml_sim/compute/tmr.py +254 -0
- space_ml_sim/compute/transformer_fault.py +340 -0
- space_ml_sim/core/__init__.py +30 -0
- space_ml_sim/core/clock.py +50 -0
- space_ml_sim/core/constellation.py +252 -0
- space_ml_sim/core/orbit.py +315 -0
- space_ml_sim/core/satellite.py +125 -0
- space_ml_sim/core/tle.py +266 -0
- space_ml_sim/environment/__init__.py +21 -0
- space_ml_sim/environment/comms.py +39 -0
- space_ml_sim/environment/power.py +30 -0
- space_ml_sim/environment/radiation.py +150 -0
- space_ml_sim/environment/thermal.py +39 -0
- space_ml_sim/environment/timeline.py +281 -0
- space_ml_sim/metrics/__init__.py +6 -0
- space_ml_sim/metrics/performance.py +43 -0
- space_ml_sim/metrics/reliability.py +48 -0
- space_ml_sim/models/__init__.py +27 -0
- space_ml_sim/models/chip_profiles.py +123 -0
- space_ml_sim/models/rad_profiles.py +13 -0
- space_ml_sim/py.typed +0 -0
- space_ml_sim/viz/__init__.py +11 -0
- space_ml_sim/viz/heatmap.py +147 -0
- space_ml_sim/viz/plots.py +166 -0
- space_ml_sim-0.3.0.dist-info/METADATA +242 -0
- space_ml_sim-0.3.0.dist-info/RECORD +36 -0
- space_ml_sim-0.3.0.dist-info/WHEEL +4 -0
- space_ml_sim-0.3.0.dist-info/licenses/LICENSE +661 -0
space_ml_sim/compute/__init__.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
"""Compute module: fault injection, TMR, checkpointing, and scheduling."""
|
|
2
|
+
|
|
3
|
+
from space_ml_sim.compute.fault_injector import FaultInjector, FaultReport
|
|
4
|
+
from space_ml_sim.compute.transformer_fault import TransformerFaultInjector
|
|
5
|
+
from space_ml_sim.compute.tmr import TMRWrapper
|
|
6
|
+
from space_ml_sim.compute.checkpoint import CheckpointManager
|
|
7
|
+
from space_ml_sim.compute.scheduler import InferenceScheduler
|
|
8
|
+
from space_ml_sim.compute.quantization import (
|
|
9
|
+
quantize_model,
|
|
10
|
+
compare_quantization_resilience,
|
|
11
|
+
plot_quantization_comparison,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
__all__ = [
|
|
15
|
+
"FaultInjector",
|
|
16
|
+
"FaultReport",
|
|
17
|
+
"TransformerFaultInjector",
|
|
18
|
+
"TMRWrapper",
|
|
19
|
+
"CheckpointManager",
|
|
20
|
+
"InferenceScheduler",
|
|
21
|
+
"quantize_model",
|
|
22
|
+
"compare_quantization_resilience",
|
|
23
|
+
"plot_quantization_comparison",
|
|
24
|
+
]
|
|
25
|
+
|
|
26
|
+
# Optional ONNX support — only available when 'onnx' and 'onnxruntime' are installed.
|
|
27
|
+
try:
|
|
28
|
+
from space_ml_sim.compute.onnx_adapter import OnnxModel, load_onnx
|
|
29
|
+
|
|
30
|
+
__all__ += ["OnnxModel", "load_onnx"]
|
|
31
|
+
except ImportError:
|
|
32
|
+
pass # onnx extras not installed
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
"""Model checkpointing for fault recovery.
|
|
2
|
+
|
|
3
|
+
Periodically saves model state so it can be restored after
|
|
4
|
+
radiation-induced corruption is detected.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import copy
|
|
10
|
+
from collections import deque
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
import torch
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class CheckpointManager:
    """Maintain a bounded history of model snapshots for fault recovery.

    A fixed-size sliding window of deep-copied state dicts lets a model be
    rolled back to a known-good state once radiation-induced corruption is
    detected.
    """

    def __init__(self, max_checkpoints: int = 3) -> None:
        """Initialize checkpoint manager.

        Args:
            max_checkpoints: Maximum number of checkpoints to retain.
        """
        # deque(maxlen=...) silently evicts the oldest snapshot once full.
        self._checkpoints: deque[dict[str, Any]] = deque(maxlen=max_checkpoints)
        self._max = max_checkpoints

    def save(self, model: torch.nn.Module, metadata: dict[str, Any] | None = None) -> int:
        """Save a model checkpoint.

        Args:
            model: Model to checkpoint.
            metadata: Optional metadata (e.g., step count, accuracy).

        Returns:
            Index of the saved checkpoint.
        """
        # Deep copy so later in-place weight corruption cannot reach the snapshot.
        snapshot = {
            "state_dict": copy.deepcopy(model.state_dict()),
            "metadata": metadata or {},
        }
        self._checkpoints.append(snapshot)
        return len(self._checkpoints) - 1

    def restore(self, model: torch.nn.Module, index: int = -1) -> dict[str, Any]:
        """Restore a model from a checkpoint.

        Args:
            model: Model to restore into.
            index: Checkpoint index (-1 for most recent).

        Returns:
            Metadata associated with the restored checkpoint.

        Raises:
            IndexError: If no checkpoints are available.
        """
        if not self._checkpoints:
            raise IndexError("No checkpoints available")

        snapshot = self._checkpoints[index]
        # No deepcopy needed — load_state_dict copies tensor data into
        # the model's existing buffers without mutating the source dict.
        model.load_state_dict(snapshot["state_dict"])
        return snapshot["metadata"]

    @property
    def count(self) -> int:
        """Number of stored checkpoints."""
        return len(self._checkpoints)

    def clear(self) -> None:
        """Remove all checkpoints."""
        self._checkpoints.clear()
|
|
@@ -0,0 +1,277 @@
|
|
|
1
|
+
"""ML-aware radiation fault injection using PyTorch hooks.
|
|
2
|
+
|
|
3
|
+
Injects radiation-modeled faults into neural network inference:
|
|
4
|
+
1. Weight SEU: Flip random bits in weight tensors (persistent memory fault)
|
|
5
|
+
2. Activation SET: Flip random bits in activations during forward pass (transient)
|
|
6
|
+
3. Stuck-at: Zero out weights in TID-degraded regions (permanent)
|
|
7
|
+
|
|
8
|
+
Key insight: MSB flips in the IEEE 754 exponent field cause catastrophic errors,
|
|
9
|
+
while LSB flips in the mantissa are often benign. Real radiation is uniform
|
|
10
|
+
across bit positions.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import copy
|
|
16
|
+
from dataclasses import dataclass
|
|
17
|
+
from typing import TYPE_CHECKING
|
|
18
|
+
|
|
19
|
+
import numpy as np
|
|
20
|
+
import torch
|
|
21
|
+
|
|
22
|
+
if TYPE_CHECKING:
|
|
23
|
+
import pandas as pd
|
|
24
|
+
|
|
25
|
+
from space_ml_sim.environment.radiation import RadiationEnvironment
|
|
26
|
+
from space_ml_sim.models.chip_profiles import ChipProfile
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass(frozen=True)
class FaultReport:
    """Immutable summary of the faults injected during one simulation run."""

    # Grand total across all fault kinds.
    total_faults_injected: int = 0
    # Persistent SEU flips applied to weight tensors.
    weight_faults: int = 0
    # Transient SET flips applied to activations.
    activation_faults: int = 0
    # Names of the parameters that received at least one flip.
    layers_affected: tuple[str, ...] = ()
    # Bit index of every flip (0 = LSB), useful for severity analysis.
    bit_positions_flipped: tuple[int, ...] = ()
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class FaultInjector:
    """Inject radiation-modeled faults into PyTorch model inference.

    Works independently of the orbital mechanics module — can be used
    standalone for fault tolerance research on any PyTorch model.
    """

    def __init__(
        self,
        rad_env: RadiationEnvironment,
        chip_profile: ChipProfile,
        seed: int | None = None,
    ) -> None:
        """Initialize the injector.

        Args:
            rad_env: Radiation environment providing ``base_seu_rate``.
            chip_profile: Hardware profile of the target chip.
            seed: Seed for the numpy RNG used to sample fault counts.
        """
        self.rad_env = rad_env
        self.chip = chip_profile
        # Fixed: torch's hook-handle class is RemovableHandle; there is no
        # torch.utils.hooks.RemovableHook.
        self._hooks: list[torch.utils.hooks.RemovableHandle] = []
        self._rng = np.random.default_rng(seed)

    @staticmethod
    def flip_random_bits(tensor: torch.Tensor, num_flips: int) -> list[int]:
        """Flip random bits in a float tensor's IEEE 754 binary representation.

        Bit positions are drawn uniformly — real radiation does not prefer
        any bit position. NOTE(review): placement uses torch's *global* RNG
        (this is a staticmethod, so the instance RNG is unavailable); seed
        with torch.manual_seed for reproducible flips.

        Args:
            tensor: A float32/float16/bfloat16 tensor (modified in-place).
            num_flips: Number of random bit flips to inject.

        Returns:
            List of bit positions that were flipped (0 to width-1).
        """
        if num_flips == 0:
            return []

        # Ensure contiguous memory — view() fails on non-contiguous tensors
        # (common with transposed weights in transformers)
        if not tensor.is_contiguous():
            tensor.data = tensor.contiguous()
        flat = tensor.view(-1)
        n = flat.numel()
        if n == 0:
            return []

        # Map float dtype to corresponding integer dtype for bit manipulation
        dtype_map = {
            torch.float32: (torch.int32, 32),
            torch.float16: (torch.int16, 16),
            torch.bfloat16: (torch.int16, 16),
        }
        int_dtype, num_bits = dtype_map.get(flat.dtype, (torch.int32, 32))

        indices = torch.randint(0, n, (num_flips,))
        bit_positions = torch.randint(0, num_bits, (num_flips,))
        # Casting 1 << (width-1) into the signed int dtype wraps to the sign
        # bit's pattern, so XOR still flips the intended MSB.
        masks = (1 << bit_positions).to(int_dtype)

        int_view = flat.view(int_dtype)
        for i in range(num_flips):
            int_view[indices[i]] ^= masks[i]

        return bit_positions.tolist()

    def inject_weight_faults(
        self,
        model: torch.nn.Module,
        num_faults: int | None = None,
        inference_time_seconds: float = 0.001,
    ) -> FaultReport:
        """Inject SEU bit flips into model weights.

        If num_faults is None, the count is sampled from a Poisson distribution
        based on the radiation environment, chip profile, and inference time.

        Args:
            model: PyTorch model (weights modified in-place).
            num_faults: Exact number of faults, or None for radiation-sampled.
            inference_time_seconds: Duration used for Poisson rate calculation.

        Returns:
            FaultReport summarizing injected faults.
        """
        if num_faults is None:
            # Use each parameter's actual storage width rather than assuming
            # float32, so fp16/bf16 models get the correct bit count.
            total_weight_bits = sum(
                p.numel() * p.element_size() * 8 for p in model.parameters()
            )
            expected = self.rad_env.base_seu_rate * total_weight_bits * inference_time_seconds
            num_faults = int(self._rng.poisson(expected))

        if num_faults == 0:
            return FaultReport()

        # Distribute faults across layers proportional to parameter count
        params = [(name, p) for name, p in model.named_parameters() if p.requires_grad]
        sizes = [p.numel() for _, p in params]
        total = sum(sizes)
        if total == 0:
            return FaultReport()

        all_layers: list[str] = []
        all_bits: list[int] = []
        total_injected = 0

        for name, param in params:
            # At least one flip per visited layer, but never exceed the budget.
            layer_faults = max(1, int(num_faults * param.numel() / total))
            layer_faults = min(layer_faults, num_faults - total_injected)
            if layer_faults <= 0:
                continue

            with torch.no_grad():
                bits = self.flip_random_bits(param.data, layer_faults)

            total_injected += layer_faults
            all_layers.append(name)
            all_bits.extend(bits)

            if total_injected >= num_faults:
                break

        return FaultReport(
            total_faults_injected=total_injected,
            weight_faults=total_injected,
            layers_affected=tuple(all_layers),
            bit_positions_flipped=tuple(all_bits),
        )

    def register_activation_hooks(
        self,
        model: torch.nn.Module,
        fault_probability: float = 0.001,
    ) -> None:
        """Register forward hooks for transient activation fault injection.

        Args:
            model: PyTorch model to hook.
            fault_probability: Probability of a bit flip per tensor element per forward pass.
        """
        # Drop any previously registered hooks so faults are not doubled up.
        self.remove_hooks()
        rng = self._rng

        def make_hook(layer_name: str):
            def hook(module, input, output):
                if isinstance(output, torch.Tensor) and output.numel() > 0:
                    # Per-element Bernoulli trials collapse to one binomial draw.
                    num_faults = int(rng.binomial(output.numel(), fault_probability))
                    if num_faults > 0:
                        FaultInjector.flip_random_bits(output, num_faults)
                return output

            return hook

        for name, module in model.named_modules():
            # Hook leaf modules only
            if len(list(module.children())) == 0:
                h = module.register_forward_hook(make_hook(name))
                self._hooks.append(h)

    def remove_hooks(self) -> None:
        """Remove all registered activation hooks."""
        for h in self._hooks:
            h.remove()
        self._hooks.clear()

    def sweep(
        self,
        model: torch.nn.Module,
        dataloader: torch.utils.data.DataLoader,
        fault_counts: list[int],
        num_trials: int = 5,
    ) -> "pd.DataFrame":
        """Run fault injection sweep across multiple fault counts and trials.

        For each fault count, runs num_trials with independent fault injections
        and measures accuracy degradation.

        Args:
            model: Base PyTorch model (left untouched; a working copy is used).
            dataloader: Evaluation data loader.
            fault_counts: List of fault counts to sweep.
            num_trials: Number of independent trials per fault count.

        Returns:
            DataFrame with columns: fault_count, trial, accuracy,
            top5_accuracy, critical_failure, faults_injected, layers_affected.
        """
        import pandas as pd

        results: list[dict] = []

        # One working model and one clean state dict serve every trial:
        # load_state_dict copies tensor data into the model without mutating
        # its source, so re-copying the clean state per trial is unnecessary.
        test_model = copy.deepcopy(model)
        clean_state = copy.deepcopy(model.state_dict())

        for fc in fault_counts:
            for trial in range(num_trials):
                test_model.load_state_dict(clean_state)
                test_model.eval()

                report = self.inject_weight_faults(test_model, num_faults=fc)

                acc, top5 = _evaluate_model(test_model, dataloader)
                results.append(
                    {
                        "fault_count": fc,
                        "trial": trial,
                        "accuracy": acc,
                        "top5_accuracy": top5,
                        "critical_failure": acc < 0.1,
                        "faults_injected": report.total_faults_injected,
                        "layers_affected": len(report.layers_affected),
                    }
                )

        return pd.DataFrame(results)
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def _evaluate_model(
|
|
251
|
+
model: torch.nn.Module,
|
|
252
|
+
dataloader: torch.utils.data.DataLoader,
|
|
253
|
+
) -> tuple[float, float]:
|
|
254
|
+
"""Evaluate a model on a dataloader, returning (top1_accuracy, top5_accuracy)."""
|
|
255
|
+
correct = 0
|
|
256
|
+
correct_top5 = 0
|
|
257
|
+
total = 0
|
|
258
|
+
|
|
259
|
+
with torch.no_grad():
|
|
260
|
+
for images, labels in dataloader:
|
|
261
|
+
outputs = model(images)
|
|
262
|
+
_, predicted = outputs.max(1)
|
|
263
|
+
correct += predicted.eq(labels).sum().item()
|
|
264
|
+
|
|
265
|
+
if outputs.size(1) >= 5:
|
|
266
|
+
_, top5_pred = outputs.topk(5, dim=1)
|
|
267
|
+
for i in range(len(labels)):
|
|
268
|
+
if labels[i] in top5_pred[i]:
|
|
269
|
+
correct_top5 += 1
|
|
270
|
+
else:
|
|
271
|
+
correct_top5 += predicted.eq(labels).sum().item()
|
|
272
|
+
|
|
273
|
+
total += labels.size(0)
|
|
274
|
+
|
|
275
|
+
acc = correct / total if total > 0 else 0.0
|
|
276
|
+
top5 = correct_top5 / total if total > 0 else 0.0
|
|
277
|
+
return acc, top5
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
"""Inference node abstraction for running ML models on a satellite."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from enum import Enum
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from pydantic import BaseModel, Field
|
|
9
|
+
|
|
10
|
+
from space_ml_sim.models.chip_profiles import ChipProfile
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class NodeStatus(str, Enum):
    """Inference node operational status.

    Subclasses ``str`` so members compare equal to their plain string
    values, which keeps them usable directly as pydantic field values
    (see InferenceResult) and JSON-serializable.
    """

    IDLE = "idle"        # ready to accept work
    RUNNING = "running"  # inference currently in progress
    FAULTED = "faulted"  # last inference raised an exception
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class InferenceResult(BaseModel):
    """Result of a single inference run."""

    # Argmax class index for each item in the batch.
    predictions: list[int] = Field(default_factory=list)
    # Wall-clock inference time; 0.0 when not measured.
    latency_ms: float = 0.0
    # Transient fault count. NOTE(review): never set to a nonzero value by
    # InferenceNode.run_inference in this file — confirm who populates it.
    faults_during_inference: int = 0
    # Node status after the run completed.
    status: NodeStatus = NodeStatus.IDLE
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class InferenceNode:
    """Wraps a PyTorch model as a compute node on a satellite.

    Tracks inference count, fault history, and thermal/power constraints.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        chip_profile: ChipProfile,
        device: str = "cpu",
    ) -> None:
        """Initialize the node.

        Args:
            model: Model to serve; moved to *device* and put in eval mode.
            chip_profile: Hardware profile supplying tdp_watts / max_temp_c limits.
            device: Torch device string.
        """
        self.model = model.to(device).eval()
        self.chip_profile = chip_profile
        self.device = device
        self.status = NodeStatus.IDLE
        self.inference_count = 0
        self.total_faults = 0

    def run_inference(self, inputs: torch.Tensor) -> InferenceResult:
        """Run inference on a batch of inputs.

        Args:
            inputs: Input tensor batch.

        Returns:
            InferenceResult with argmax predictions, measured latency, and status.

        Raises:
            Exception: Re-raises any model failure after marking the node FAULTED.
        """
        import time  # local import, mirroring the file's lazy-import style

        self.status = NodeStatus.RUNNING
        start = time.perf_counter()
        try:
            with torch.no_grad():
                outputs = self.model(inputs.to(self.device))
                predictions = outputs.argmax(dim=1).tolist()
        except Exception:
            # Stay FAULTED so a scheduler can route work away from this node.
            self.status = NodeStatus.FAULTED
            raise

        # Previously latency_ms was never populated; measure it so the
        # existing InferenceResult field carries real data.
        latency_ms = (time.perf_counter() - start) * 1000.0

        self.inference_count += 1
        self.status = NodeStatus.IDLE

        return InferenceResult(
            predictions=predictions,
            latency_ms=latency_ms,
            status=NodeStatus.IDLE,
        )

    def can_run(self, power_available_w: float, temperature_c: float) -> bool:
        """Check if the node has enough power and is within thermal limits.

        Args:
            power_available_w: Available power in watts.
            temperature_c: Current temperature in Celsius.

        Returns:
            True if the node can safely run inference.
        """
        return (
            power_available_w >= self.chip_profile.tdp_watts
            and temperature_c <= self.chip_profile.max_temp_c
        )