sdasim 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sdasim/__init__.py +90 -0
- sdasim/_compat.py +194 -0
- sdasim/_version.py +6 -0
- sdasim/batch.py +346 -0
- sdasim/calibrate.py +621 -0
- sdasim/cli.py +237 -0
- sdasim/config.py +162 -0
- sdasim/device.py +44 -0
- sdasim/empirical.py +682 -0
- sdasim/fpa.py +110 -0
- sdasim/io.py +174 -0
- sdasim/noise.py +50 -0
- sdasim/render.py +244 -0
- sdasim/sampler.py +312 -0
- sdasim/scene.py +484 -0
- sdasim/splat.py +426 -0
- sdasim/sstr7.py +422 -0
- sdasim/stars.py +161 -0
- sdasim/streak_vis.py +300 -0
- sdasim/targets.py +74 -0
- sdasim-0.2.0.dist-info/METADATA +327 -0
- sdasim-0.2.0.dist-info/RECORD +25 -0
- sdasim-0.2.0.dist-info/WHEEL +4 -0
- sdasim-0.2.0.dist-info/entry_points.txt +4 -0
- sdasim-0.2.0.dist-info/licenses/LICENSE +21 -0
sdasim/__init__.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
"""sdasim — Speed-optimized differentiable satellite scene simulator."""
|
|
2
|
+
|
|
3
|
+
from sdasim._version import __version__
|
|
4
|
+
from sdasim.batch import BatchRenderResult, render_scene_batch
|
|
5
|
+
from sdasim.config import (
|
|
6
|
+
SceneConfig,
|
|
7
|
+
SensorConfig,
|
|
8
|
+
StarFieldConfig,
|
|
9
|
+
StarMotionConfig,
|
|
10
|
+
TargetConfig,
|
|
11
|
+
load_config,
|
|
12
|
+
)
|
|
13
|
+
from sdasim.device import get_device, resolve_device, set_device
|
|
14
|
+
from sdasim.empirical import (
|
|
15
|
+
EmpiricalNoise,
|
|
16
|
+
EmpiricalPSF,
|
|
17
|
+
render_frame_empirical,
|
|
18
|
+
)
|
|
19
|
+
from sdasim.fpa import analog_to_digital, eod_to_sigma, mv_to_pe, pe_to_mv
|
|
20
|
+
from sdasim.noise import gaussian_noise, poisson_noise
|
|
21
|
+
from sdasim.render import expand_motion, render_frame
|
|
22
|
+
from sdasim.scene import Scene
|
|
23
|
+
from sdasim.splat import (
|
|
24
|
+
splat_elliptical_gaussian_batched,
|
|
25
|
+
splat_gaussians,
|
|
26
|
+
splat_gaussians_batched,
|
|
27
|
+
splat_moffat_batched,
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def __getattr__(name: str):
    """Resolve lazily-loaded package attributes on first access.

    ``io`` maps to the ``sdasim.io`` submodule; ``sampler``,
    ``SceneDistribution`` and ``random_scene`` come from ``sdasim.sampler``.
    Whatever gets resolved is stored in the package globals so this hook is
    not consulted again for the same name.

    Raises:
        AttributeError: for any name that is not one of the lazy attributes.
    """
    namespace = globals()

    if name == "io":
        import importlib

        io_module = importlib.import_module("sdasim.io")
        namespace["io"] = io_module  # cache so __getattr__ isn't called again
        return io_module

    if name in ("sampler", "SceneDistribution", "random_scene"):
        import importlib

        sampler_module = importlib.import_module("sdasim.sampler")
        # One import populates all three sampler-related names at once.
        namespace["sampler"] = sampler_module
        for exported in ("SceneDistribution", "random_scene"):
            namespace[exported] = getattr(sampler_module, exported)
        return namespace[name]

    raise AttributeError(f"module 'sdasim' has no attribute {name!r}")
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
# Names exported by ``from sdasim import *``. The last four entries ("io",
# "sampler", "SceneDistribution", "random_scene") are not imported eagerly;
# the module-level __getattr__ resolves them on first access.
__all__ = [
    "__version__",
    # Device selection
    "get_device", "set_device", "resolve_device",
    # Core: splatting, noise, focal-plane conversions
    "splat_gaussians", "splat_gaussians_batched",
    "splat_moffat_batched", "splat_elliptical_gaussian_batched",
    "poisson_noise", "gaussian_noise",
    "analog_to_digital", "mv_to_pe", "pe_to_mv", "eod_to_sigma",
    # Rendering
    "render_frame", "expand_motion",
    "render_scene_batch", "BatchRenderResult",
    # Empirical models (opt-in)
    "EmpiricalPSF", "EmpiricalNoise", "render_frame_empirical",
    # Scene
    "Scene",
    # Configuration
    "SceneConfig", "SensorConfig", "StarFieldConfig",
    "StarMotionConfig", "TargetConfig", "load_config",
    # I/O (lazy)
    "io",
    # Sampler (lazy)
    "sampler", "SceneDistribution", "random_scene",
]
|
sdasim/_compat.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
"""satsim configuration converter.
|
|
2
|
+
|
|
3
|
+
Maps satsim config dicts to sdasim SceneConfig. If the satsim config uses
|
|
4
|
+
$sample/$ref/$generator, calls satsim.config.loading.realize() first
|
|
5
|
+
(requires satsim installed). For flat configs, no satsim dependency needed.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
from sdasim.config import (
|
|
13
|
+
SceneConfig,
|
|
14
|
+
SensorConfig,
|
|
15
|
+
StarFieldConfig,
|
|
16
|
+
StarMotionConfig,
|
|
17
|
+
TargetConfig,
|
|
18
|
+
)
|
|
19
|
+
from sdasim.fpa import eod_to_sigma
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _has_dynamic_keys(d: Any) -> bool:
    """Check if a dict tree contains satsim dynamic keys ($sample, $ref, etc.).

    Dicts and lists are walked recursively; any other value is a leaf. A key
    counts as dynamic when it is a string starting with ``"$"``.

    Args:
        d: Arbitrary (possibly nested) configuration value.

    Returns:
        True if any dict anywhere in the tree has a ``$``-prefixed string key.
    """
    if isinstance(d, dict):
        for key, value in d.items():
            # Non-string keys (e.g. integer keys from YAML) can never be
            # dynamic markers; skip the prefix test instead of crashing.
            if isinstance(key, str) and key.startswith("$"):
                return True
            if _has_dynamic_keys(value):
                return True
        return False
    if isinstance(d, list):
        return any(_has_dynamic_keys(item) for item in d)
    return False
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _get(d: dict, *keys: str, default: Any = None) -> Any:
    """Walk ``keys`` down a nested dict, falling back to ``default``.

    The fallback is returned as soon as a key is absent or the value reached
    so far is not a dict. A key that is present with a value (including
    ``None``) at the end of the path is returned as-is.
    """
    current = d
    for key in keys:
        # A non-dict node cannot be descended into; treat it like a miss.
        current = current.get(key, default) if isinstance(current, dict) else default
        if current is default:
            return default
    return current
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def from_satsim_config(satsim_dict: dict, seed: int | None = None) -> SceneConfig:
    """Convert a satsim configuration dict to a sdasim SceneConfig.

    If the config contains dynamic keys ($sample, $ref, $generator), this
    function calls satsim.config.loading.realize() first. Every value that is
    missing from the (realized) config falls back to the default written next
    to its lookup below.

    Args:
        satsim_dict: satsim configuration dictionary.
        seed: Random seed for config resolution. Also stored on the returned
            SceneConfig.

    Returns:
        SceneConfig ready for Scene construction.

    Raises:
        ImportError: if the config has dynamic keys but satsim cannot be
            imported. The original import failure is attached as the cause.
    """
    cfg = satsim_dict

    # Resolve dynamic keys if present
    if _has_dynamic_keys(cfg):
        try:
            from satsim.config.loading import realize
        except ImportError as err:
            # Chain the cause so the underlying import problem stays visible.
            raise ImportError(
                "satsim config contains dynamic keys ($sample/$ref/$generator) "
                "but satsim is not installed. Install it or use a flat config."
            ) from err
        cfg = realize(cfg, seed=seed)

    # Extract FPA / sensor config
    fpa = _get(cfg, "fpa", default={})
    height = _get(fpa, "height", default=512)
    width = _get(fpa, "width", default=512)
    y_fov = _get(fpa, "y_fov", default=0.5)
    x_fov = _get(fpa, "x_fov", default=0.5)

    time_cfg = _get(fpa, "time", default={})
    exposure = _get(time_cfg, "exposure", default=2.0)
    gap = _get(time_cfg, "gap", default=0.5)

    num_frames = _get(fpa, "num_frames", default=1)
    zeropoint = _get(fpa, "zeropoint", default=23.5)

    # PSF: convert from EOD if Gaussian PSF specified
    psf_cfg = _get(fpa, "psf", default={})
    if not isinstance(psf_cfg, dict):
        # An explicit null (or scalar) section would break the membership
        # tests below; treat it like a missing section, as _get does elsewhere.
        psf_cfg = {}
    osf = _get(fpa, "s_osf", default=1)
    if "eod" in psf_cfg:
        psf_sigma = eod_to_sigma(psf_cfg["eod"], osf=1.0)  # native resolution
    elif "sigma" in psf_cfg:
        psf_sigma = psf_cfg["sigma"] / osf  # convert from oversampled to native
    else:
        psf_sigma = 1.5

    # Noise
    noise_cfg = _get(fpa, "noise", default={})
    read_noise = _get(noise_cfg, "read", default=10.0)
    electronic_noise = _get(noise_cfg, "electronic", default=5.0)

    # A2D
    a2d_cfg = _get(fpa, "a2d", default={})
    gain = _get(a2d_cfg, "gain", default=8.0)
    fwc = _get(a2d_cfg, "fwc", default=100000.0)
    a2d_bias = _get(a2d_cfg, "bias", default=500.0)
    a2d_dtype = _get(a2d_cfg, "dtype", default="uint16")

    # Dark current
    dark_current = _get(fpa, "dark_current", default=10.0)

    # Background
    bg_cfg = _get(cfg, "background", default={})
    background_mv = _get(bg_cfg, "galactic", "mv", default=21.0)
    if background_mv is None:
        # Key present but null: use the same default as a missing key.
        background_mv = 21.0

    # Bias
    bias = _get(fpa, "bias", default=50.0)

    sensor = SensorConfig(
        height=height,
        width=width,
        y_fov=y_fov,
        x_fov=x_fov,
        exposure=exposure,
        gap=gap,
        num_frames=num_frames,
        zeropoint=zeropoint,
        psf_sigma=psf_sigma,
        dark_current=dark_current,
        read_noise=read_noise,
        electronic_noise=electronic_noise,
        background_mv=background_mv,
        bias=bias,
        gain=gain,
        fwc=fwc,
        a2d_bias=a2d_bias,
        a2d_dtype=a2d_dtype,
    )

    # Stars
    geom = _get(cfg, "geometry", default={})
    stars_cfg = _get(geom, "stars", default={})
    if not isinstance(stars_cfg, dict):
        # Same null/scalar guard as for the PSF section.
        stars_cfg = {}

    star_mode = _get(stars_cfg, "mode", default="bins")
    if star_mode == "random" or "mv" in stars_cfg:
        # Random bins mode
        mv_bins = _get(stars_cfg, "mv", "bins", default=[6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
        density = _get(stars_cfg, "mv", "density", default=[1.0] * (len(mv_bins) - 1))
        stars = StarFieldConfig(mode="bins", mv_bins=mv_bins, density=density)
    elif star_mode == "sstr7":
        stars = StarFieldConfig(
            mode="sstr7",
            catalog_path=_get(stars_cfg, "path", default=None),
        )
    else:
        stars = StarFieldConfig(mode="bins")

    # Star motion
    star_motion_cfg = _get(geom, "star_motion", default={})

    rotation = _get(star_motion_cfg, "rotation", default=0.0)
    translation = _get(star_motion_cfg, "translation", default=[0.0, 0.0])
    # Top-level t_osf wins; otherwise fpa.t_osf; otherwise 100.
    t_osf = _get(cfg, "t_osf", default=_get(fpa, "t_osf", default=100))

    star_motion = StarMotionConfig(
        rotation=rotation,
        translation=translation,
        temporal_osf=t_osf,
    )

    # Targets (observation objects)
    obs_list = _get(geom, "obs", default=[])
    if obs_list is None:
        # "obs: null" means no targets, not one all-default target.
        obs_list = []
    elif not isinstance(obs_list, list):
        # A single observation object is accepted in place of a list.
        obs_list = [obs_list]

    targets = []
    for obs in obs_list:
        origin = _get(obs, "origin", default=[0.5, 0.5])
        velocity = _get(obs, "velocity", default=[0.0, 0.0])
        mv = _get(obs, "mv", default=12.0)
        mode = _get(obs, "mode", default="line")
        targets.append(TargetConfig(mode=mode, origin=origin, velocity=velocity, mv=mv))

    return SceneConfig(
        sensor=sensor,
        stars=stars,
        star_motion=star_motion,
        targets=targets,
        seed=seed,
        device="auto",
        enable_shot_noise=_get(noise_cfg, "photon", default=True),
        enable_read_noise=read_noise > 0 or electronic_noise > 0,
    )
|
sdasim/_version.py
ADDED
sdasim/batch.py
ADDED
|
@@ -0,0 +1,346 @@
|
|
|
1
|
+
"""Batched multi-scene rendering.
|
|
2
|
+
|
|
3
|
+
Fuses N heterogeneous Scenes into a single kernel launch. Each scene can have
|
|
4
|
+
its own star catalog, its own PSF sigma, its own noise/gain/background/etc.
|
|
5
|
+
|
|
6
|
+
Restrictions in this first pass (matches what zerosda's pretraining loop
|
|
7
|
+
needs):
|
|
8
|
+
- All scenes must share (height, width) and a2d_dtype.
|
|
9
|
+
- Only supports frame_idx=0 and mode=None. Rate_sidereal mode-dispatch
|
|
10
|
+
is left to the non-batched Scene.render() path.
|
|
11
|
+
- Star motion (rate tracking) is supported per-scene via expand_motion.
|
|
12
|
+
- Targets are rendered with their own velocities, also via expand_motion.
|
|
13
|
+
|
|
14
|
+
Key win: all stars and all targets across all B scenes are splatted with two
|
|
15
|
+
kernel launches total (one for stars, one for targets), instead of 2*B.
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
20
|
+
import math
|
|
21
|
+
from dataclasses import dataclass
|
|
22
|
+
from typing import Sequence
|
|
23
|
+
|
|
24
|
+
import torch
|
|
25
|
+
from torch import Tensor
|
|
26
|
+
|
|
27
|
+
from sdasim.fpa import MAX_PIXEL_VALUE
|
|
28
|
+
from sdasim.noise import poisson_noise
|
|
29
|
+
from sdasim.render import expand_motion
|
|
30
|
+
from sdasim.scene import Scene
|
|
31
|
+
from sdasim.splat import splat_gaussians_batched
|
|
32
|
+
from sdasim.targets import compute_target_positions
|
|
33
|
+
|
|
34
|
+
# Name of the attribute under which _ensure_source_cache stores (and later
# looks up) the per-scene cache dict on a Scene instance.
_CACHE_ATTR = "_batch_source_cache"
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _ensure_source_cache(scene: Scene) -> dict:
    """Return the batch-render source cache for ``scene``, building it on first use.

    The cache holds the star and target source lists (motion-expanded where
    applicable) together with the scalar sensor/scene parameters the batched
    renderer needs. It is stored on the scene under ``_CACHE_ATTR`` and
    returned unchanged on every later call, so the expansion work is paid
    only once per scene.
    """
    existing = getattr(scene, _CACHE_ATTR, None)
    if existing is not None:
        return existing

    sensor = scene.sensor
    motion = scene.config.star_motion
    device = scene.device
    center = (sensor.height / 2.0, sensor.width / 2.0)

    # --- Stars: blur within the frame only when the field actually moves ---
    star_pos = scene.star_positions
    star_int = scene.star_intensities
    stars_move = not (
        motion.rotation == 0.0
        and motion.translation[0] == 0.0
        and motion.translation[1] == 0.0
    )
    if stars_move and motion.temporal_osf > 1 and star_pos.shape[0] > 0:
        star_pos, star_int = expand_motion(
            star_pos,
            star_int,
            motion.translation,
            motion.rotation,
            0.0,
            sensor.exposure,
            motion.temporal_osf,
            center,
        )

    # --- Targets: sub-sample count scales with the longest streak ---
    tgt_pos, tgt_int, tgt_vel = compute_target_positions(
        scene.config.targets, sensor, 0, device
    )
    if tgt_pos.shape[0] == 0:
        # No targets: use canonical empty float32 tensors on the scene device.
        tgt_pos = torch.zeros(0, 2, device=device, dtype=torch.float32)
        tgt_int = torch.zeros(0, device=device, dtype=torch.float32)
    else:
        # Largest absolute velocity component across all targets.
        peak_velocity = float(tgt_vel.abs().max().item()) if tgt_vel.numel() else 0.0
        n_substeps = max(1, int(peak_velocity * sensor.exposure * 2))
        if n_substeps > 1:
            tgt_pos, tgt_int = expand_motion(
                tgt_pos,
                tgt_int,
                tgt_vel,
                0.0,
                0.0,
                sensor.exposure,
                n_substeps,
                center,
            )

    built = {
        "star_pos": star_pos,
        "star_int": star_int,
        "star_count": int(star_pos.shape[0]),
        "tgt_pos": tgt_pos,
        "tgt_int": tgt_int,
        "tgt_count": int(tgt_pos.shape[0]),
        "psf_sigma": float(sensor.psf_sigma),
        "background_pe": float(scene.background_pe),
        "dark_current_pe": float(scene.dark_current_pe),
        "bias_pe": float(scene.bias_pe),
        "read_noise": float(sensor.read_noise),
        "electronic_noise": float(sensor.electronic_noise),
        "gain": float(sensor.gain),
        "fwc": float(sensor.fwc),
        "a2d_bias": float(sensor.a2d_bias),
    }
    setattr(scene, _CACHE_ATTR, built)
    return built
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
@dataclass
class BatchRenderResult:
    """Bundle of tensors returned by render_scene_batch.

    B is the number of scenes in the batch and (H, W) the shared sensor size.

    Attributes:
        digital: [B, H, W] digital image in ADU.
        star_signal: [B, H, W] star-only signal in PE, before noise.
        target_signal: [B, H, W] target-only signal in PE, before noise.
        star_positions_per_frame: one [N_b, 2] tensor per scene holding that
            scene's Scene.star_positions (i.e. before motion expansion), so
            training heatmaps can be built without splatting again.
        num_stars: [B] integer tensor with the pre-expansion star count of
            each scene.
    """

    digital: Tensor
    star_signal: Tensor
    target_signal: Tensor
    star_positions_per_frame: list[Tensor]
    num_stars: Tensor
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def _collect_cached(
    scenes: Sequence[Scene],
    device: torch.device,
    kind: str,
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
    """Concatenate the cached star or target sources of all scenes.

    ``kind`` selects the cache entries (``"star"`` or ``"tgt"``). Scenes with
    zero sources of that kind are skipped. The result is the tuple
    (positions, intensities, frame_ids, per_source_sigma), where frame_ids
    tags every source with the index of its scene in ``scenes`` and
    per_source_sigma repeats that scene's PSF sigma for each of its sources.
    When no scene contributes any source, empty tensors are returned.
    """
    pos_key, int_key, cnt_key = (f"{kind}_{suffix}" for suffix in ("pos", "int", "count"))

    # (scene index, cache dict) for every scene that has at least one source.
    contributing = []
    for index, scene in enumerate(scenes):
        entry = _ensure_source_cache(scene)
        if entry[cnt_key] != 0:
            contributing.append((index, entry))

    if not contributing:
        empty_pos = torch.zeros(0, 2, dtype=torch.float32, device=device)
        empty_vec = torch.zeros(0, dtype=torch.float32, device=device)
        empty_ids = torch.zeros(0, dtype=torch.long, device=device)
        return empty_pos, empty_vec, empty_ids, empty_vec

    positions = torch.cat([entry[pos_key] for _, entry in contributing], dim=0)
    intensities = torch.cat([entry[int_key] for _, entry in contributing], dim=0)

    repeats = torch.tensor(
        [entry[cnt_key] for _, entry in contributing], dtype=torch.long, device=device
    )
    scene_ids = torch.tensor(
        [index for index, _ in contributing], dtype=torch.long, device=device
    )
    scene_sigmas = torch.tensor(
        [entry["psf_sigma"] for _, entry in contributing],
        dtype=torch.float32,
        device=device,
    )

    # Expand the per-scene values to one entry per source.
    frame_ids = torch.repeat_interleave(scene_ids, repeats)
    per_source_sigma = torch.repeat_interleave(scene_sigmas, repeats)
    return positions, intensities, frame_ids, per_source_sigma
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def _render_batch_empirical(scenes: Sequence[Scene]) -> BatchRenderResult:
    """Batch fallback used when any scene selects an empirical PSF or noise model.

    These scenes are not fused into one splat; instead each scene is rendered
    on its own through ``Scene.render_signals(0)`` and the per-scene outputs
    are stacked along a new leading batch dimension.
    """
    device = scenes[0].device

    per_scene_outputs = []
    star_positions = []
    for scene in scenes:
        digital, star_signal, target_signal, _unused = scene.render_signals(0)
        per_scene_outputs.append((digital, star_signal, target_signal))
        star_positions.append(scene.star_positions)

    num_stars = torch.tensor(
        [positions.shape[0] for positions in star_positions],
        dtype=torch.long,
        device=device,
    )
    digitals, star_signals, target_signals = zip(*per_scene_outputs)
    return BatchRenderResult(
        digital=torch.stack(digitals),
        star_signal=torch.stack(star_signals),
        target_signal=torch.stack(target_signals),
        star_positions_per_frame=star_positions,
        num_stars=num_stars,
    )
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def _per_scene_column(values: Sequence[float], device: torch.device) -> Tensor:
    """Pack one float per scene into a (B, 1, 1) float32 tensor that broadcasts over (B, H, W)."""
    return torch.tensor(values, dtype=torch.float32, device=device).view(len(values), 1, 1)


def render_scene_batch(scenes: Sequence[Scene]) -> BatchRenderResult:
    """Render N scenes in a single fused pass.

    Each scene is treated as its own telescope/exposure with its own PSF,
    background, read noise, gain, etc. All scenes contribute to one
    (B, H, W) output via per-source frame_id tagging on the splats.

    Shot noise and read noise are applied per scene according to that scene's
    ``config.enable_shot_noise`` / ``config.enable_read_noise`` flags; a scene
    that disables a noise source does not disable it for the rest of the
    batch.

    If any scene uses an empirical PSF or noise model, the whole batch is
    rendered scene-by-scene via ``_render_batch_empirical`` instead.

    Args:
        scenes: Sequence of Scene objects. All must share (height, width,
            a2d_dtype, device). Scene.config.mode must be None for each scene.

    Returns:
        BatchRenderResult with digital, star_signal, target_signal, and the
        per-frame pre-expansion star positions (handy for building heatmaps).

    Raises:
        ValueError: if ``scenes`` is empty, if sensor sizes or a2d_dtype
            differ between scenes, or if any scene has a non-None mode.
    """
    if len(scenes) == 0:
        raise ValueError("render_scene_batch requires at least one scene")

    ref = scenes[0]
    device = ref.device
    H = ref.sensor.height
    W = ref.sensor.width
    a2d_dtype = ref.sensor.a2d_dtype
    B = len(scenes)

    for s in scenes:
        if s.sensor.height != H or s.sensor.width != W:
            raise ValueError(
                "render_scene_batch requires all scenes to share (height, width); "
                f"got {(s.sensor.height, s.sensor.width)} vs {(H, W)}"
            )
        if s.sensor.a2d_dtype != a2d_dtype:
            raise ValueError("render_scene_batch requires all scenes to share a2d_dtype")
        if s.config.mode is not None:
            raise ValueError(
                "render_scene_batch only supports mode=None (sidereal/rate_track); "
                f"got mode={s.config.mode}"
            )

    # Empirical scenes use per-scene sampled kernels + FFT -> render per-scene and stack.
    if any(
        s.sensor.psf_model == "empirical" or s.sensor.noise_model == "empirical" for s in scenes
    ):
        return _render_batch_empirical(scenes)

    # Ensure caches exist (lazy, idempotent on pool scenes) and keep them for
    # the per-scene scalar parameters used below.
    caches = [_ensure_source_cache(s) for s in scenes]

    per_frame_positions = [s.star_positions for s in scenes]
    star_pos, star_int, star_fid, star_sig = _collect_cached(scenes, device, "star")
    tgt_pos, tgt_int, tgt_fid, tgt_sig = _collect_cached(scenes, device, "tgt")

    # Single fused splat per kind. splat_gaussians_batched handles the empty
    # source case by returning zeros of the right shape.
    star_signal = splat_gaussians_batched(B, H, W, star_pos, star_int, star_fid, star_sig)
    target_signal = splat_gaussians_batched(B, H, W, tgt_pos, tgt_int, tgt_fid, tgt_sig)

    # Additive per-scene floors: sky background, dark current and bias (all PE).
    bg = _per_scene_column([c["background_pe"] for c in caches], device)
    dc = _per_scene_column([c["dark_current_pe"] for c in caches], device)
    bias = _per_scene_column([c["bias_pe"] for c in caches], device)
    signal = star_signal + target_signal
    signal = signal + bg + dc + bias

    # Shot noise, honoured per scene. The all-enabled case takes the same
    # single poisson_noise call as before; a mixed batch keeps the noiseless
    # signal for the scenes that opted out.
    shot_flags = [bool(s.config.enable_shot_noise) for s in scenes]
    if all(shot_flags):
        signal = poisson_noise(signal)
    elif any(shot_flags):
        shot_mask = torch.tensor(shot_flags, dtype=torch.bool, device=device).view(B, 1, 1)
        signal = torch.where(shot_mask, poisson_noise(signal), signal)

    # Read + electronic noise combined in quadrature; sigma is forced to zero
    # for scenes that disabled read noise so they receive no contribution.
    read_flags = [bool(s.config.enable_read_noise) for s in scenes]
    if any(read_flags):
        rn_sigma = _per_scene_column(
            [
                math.sqrt(c["read_noise"] ** 2 + c["electronic_noise"] ** 2) if enabled else 0.0
                for c, enabled in zip(caches, read_flags)
            ],
            device,
        )
        signal = signal + rn_sigma * torch.randn_like(signal)

    a2d_bias_t = _per_scene_column([c["a2d_bias"] for c in caches], device)
    fwc_t = _per_scene_column([c["fwc"] for c in caches], device)
    gain_t = _per_scene_column([c["gain"] for c in caches], device)

    # Analog-to-digital: add the A2D bias, clip to [0, full-well capacity],
    # divide by gain, floor, then clamp to the range of the output dtype.
    biased = (signal + a2d_bias_t).clamp(min=0.0)
    biased = torch.minimum(biased, fwc_t)
    dn = torch.floor(biased / gain_t)
    max_val = MAX_PIXEL_VALUE.get(a2d_dtype, 65535.0)
    dn = dn.clamp(min=0.0, max=max_val)

    num_stars = torch.tensor(
        [p.shape[0] for p in per_frame_positions],
        dtype=torch.long,
        device=device,
    )

    return BatchRenderResult(
        digital=dn,
        star_signal=star_signal,
        target_signal=target_signal,
        star_positions_per_frame=per_frame_positions,
        num_stars=num_stars,
    )
|