sdasim 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sdasim/__init__.py ADDED
@@ -0,0 +1,90 @@
1
+ """sdasim — Speed-optimized differentiable satellite scene simulator."""
2
+
3
+ from sdasim._version import __version__
4
+ from sdasim.batch import BatchRenderResult, render_scene_batch
5
+ from sdasim.config import (
6
+ SceneConfig,
7
+ SensorConfig,
8
+ StarFieldConfig,
9
+ StarMotionConfig,
10
+ TargetConfig,
11
+ load_config,
12
+ )
13
+ from sdasim.device import get_device, resolve_device, set_device
14
+ from sdasim.empirical import (
15
+ EmpiricalNoise,
16
+ EmpiricalPSF,
17
+ render_frame_empirical,
18
+ )
19
+ from sdasim.fpa import analog_to_digital, eod_to_sigma, mv_to_pe, pe_to_mv
20
+ from sdasim.noise import gaussian_noise, poisson_noise
21
+ from sdasim.render import expand_motion, render_frame
22
+ from sdasim.scene import Scene
23
+ from sdasim.splat import (
24
+ splat_elliptical_gaussian_batched,
25
+ splat_gaussians,
26
+ splat_gaussians_batched,
27
+ splat_moffat_batched,
28
+ )
29
+
30
+
31
def __getattr__(name: str):
    """Resolve the optional ``io`` / ``sampler`` exports on first access (PEP 562).

    ``io`` maps to the ``sdasim.io`` submodule. ``sampler``,
    ``SceneDistribution`` and ``random_scene`` all come from ``sdasim.sampler``.
    Whatever gets imported is written into this module's globals, so later
    lookups never reach this hook again. Any other name raises AttributeError.
    """
    import importlib

    if name == "io":
        io_module = importlib.import_module("sdasim.io")
        globals()["io"] = io_module  # cache so __getattr__ isn't called again
        return io_module

    if name not in ("sampler", "SceneDistribution", "random_scene"):
        raise AttributeError(f"module 'sdasim' has no attribute {name!r}")

    sampler_module = importlib.import_module("sdasim.sampler")
    namespace = globals()
    # Cache the module and both re-exported names in one go.
    namespace["sampler"] = sampler_module
    namespace["SceneDistribution"] = sampler_module.SceneDistribution
    namespace["random_scene"] = sampler_module.random_scene
    return namespace[name]
47
+
48
+
49
# Public API. The "lazy" names are not imported above; the module-level
# __getattr__ resolves them on demand. Because they are listed here,
# ``from sdasim import *`` looks each one up and therefore imports
# sdasim.io and sdasim.sampler eagerly.
__all__ = [
    "__version__",
    # Device
    "get_device", "set_device", "resolve_device",
    # Core
    "splat_gaussians", "splat_gaussians_batched", "splat_moffat_batched",
    "splat_elliptical_gaussian_batched",
    "poisson_noise", "gaussian_noise",
    "analog_to_digital", "mv_to_pe", "pe_to_mv", "eod_to_sigma",
    # Render
    "render_frame", "expand_motion", "render_scene_batch", "BatchRenderResult",
    # Empirical (opt-in)
    "EmpiricalPSF", "EmpiricalNoise", "render_frame_empirical",
    # Scene
    "Scene",
    # Config
    "SceneConfig", "SensorConfig", "StarFieldConfig", "StarMotionConfig",
    "TargetConfig", "load_config",
    # I/O (lazy)
    "io",
    # Sampler (lazy)
    "sampler", "SceneDistribution", "random_scene",
]
sdasim/_compat.py ADDED
@@ -0,0 +1,194 @@
1
+ """satsim configuration converter.
2
+
3
+ Maps satsim config dicts to sdasim SceneConfig. If the satsim config uses
4
+ $sample/$ref/$generator, calls satsim.config.loading.realize() first
5
+ (requires satsim installed). For flat configs, no satsim dependency needed.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import Any
11
+
12
+ from sdasim.config import (
13
+ SceneConfig,
14
+ SensorConfig,
15
+ StarFieldConfig,
16
+ StarMotionConfig,
17
+ TargetConfig,
18
+ )
19
+ from sdasim.fpa import eod_to_sigma
20
+
21
+
22
def _has_dynamic_keys(d: Any) -> bool:
    """Return True if a config tree contains satsim dynamic keys ($sample, $ref, etc.).

    A dynamic key is a dict key that is a string starting with ``$``.
    - Dicts are checked key by key, and each value is searched recursively.
    - Lists and tuples are searched element-wise.
    - Any other value is a leaf and yields False.

    Non-string keys, such as the integer keys a YAML loader can produce, are
    never dynamic. They are skipped rather than raising AttributeError on
    ``.startswith``.
    """
    if isinstance(d, dict):
        return any(
            (isinstance(k, str) and k.startswith("$")) or _has_dynamic_keys(v)
            for k, v in d.items()
        )
    if isinstance(d, (list, tuple)):
        return any(_has_dynamic_keys(item) for item in d)
    return False
33
+
34
+
35
def _get(d: dict, *keys: str, default: Any = None) -> Any:
    """Follow ``keys`` down a nested dict and return the value reached.

    The walk returns ``default`` as soon as the current level is not a dict
    or a key is missing. If a looked-up value *is* the ``default`` object
    itself, the walk also stops and yields ``default``. When no keys are
    given, ``d`` is returned unchanged.
    """
    node: Any = d
    for key in keys:
        if not isinstance(node, dict):
            return default
        node = node.get(key, default)
        if node is default:
            break  # node is the default object here, so returning it is equivalent
    return node
44
+
45
+
46
def from_satsim_config(satsim_dict: dict, seed: int | None = None) -> SceneConfig:
    """Convert a satsim configuration dict to a sdasim SceneConfig.

    If the config contains dynamic keys ($sample, $ref, $generator), this
    function calls satsim.config.loading.realize() first.

    Every section and key is optional and falls back to the defaults below.
    A section that is present but is not a dict (for example a bare ``psf:``
    that YAML loads as ``None``) is treated as missing instead of raising.

    Args:
        satsim_dict: satsim configuration dictionary.
        seed: Random seed for config resolution. It is also stored on the
            returned SceneConfig.

    Returns:
        SceneConfig ready for Scene construction.

    Raises:
        ImportError: If the config contains dynamic keys but satsim is not
            installed.
    """

    def _section(parent: Any, key: str) -> dict:
        # Sub-configs that are probed with ``in`` must be real dicts. None
        # would raise TypeError, and a str would do a substring test.
        value = _get(parent, key, default=None)
        return value if isinstance(value, dict) else {}

    cfg = satsim_dict

    # Resolve dynamic keys if present
    if _has_dynamic_keys(cfg):
        try:
            from satsim.config.loading import realize
        except ImportError as err:
            raise ImportError(
                "satsim config contains dynamic keys ($sample/$ref/$generator) "
                "but satsim is not installed. Install it or use a flat config."
            ) from err
        cfg = realize(cfg, seed=seed)

    # Extract FPA / sensor config
    fpa = _section(cfg, "fpa")
    height = _get(fpa, "height", default=512)
    width = _get(fpa, "width", default=512)
    y_fov = _get(fpa, "y_fov", default=0.5)
    x_fov = _get(fpa, "x_fov", default=0.5)

    time_cfg = _section(fpa, "time")
    exposure = _get(time_cfg, "exposure", default=2.0)
    gap = _get(time_cfg, "gap", default=0.5)

    num_frames = _get(fpa, "num_frames", default=1)
    zeropoint = _get(fpa, "zeropoint", default=23.5)

    # PSF:
    # - "eod" is converted directly at native resolution.
    # - An explicit "sigma" is given at the spatial oversampling factor
    #   (s_osf), so it is scaled back to native pixels.
    psf_cfg = _section(fpa, "psf")
    osf = _get(fpa, "s_osf", default=1)
    if "eod" in psf_cfg:
        psf_sigma = eod_to_sigma(psf_cfg["eod"], osf=1.0)  # native resolution
    elif "sigma" in psf_cfg:
        psf_sigma = psf_cfg["sigma"] / osf  # convert from oversampled to native
    else:
        psf_sigma = 1.5

    # Noise
    noise_cfg = _section(fpa, "noise")
    read_noise = _get(noise_cfg, "read", default=10.0)
    electronic_noise = _get(noise_cfg, "electronic", default=5.0)

    # A2D
    a2d_cfg = _section(fpa, "a2d")
    gain = _get(a2d_cfg, "gain", default=8.0)
    fwc = _get(a2d_cfg, "fwc", default=100000.0)
    a2d_bias = _get(a2d_cfg, "bias", default=500.0)
    a2d_dtype = _get(a2d_cfg, "dtype", default="uint16")

    # Dark current
    dark_current = _get(fpa, "dark_current", default=10.0)

    # Background
    bg_cfg = _section(cfg, "background")
    # NOTE(review): only a nested ``galactic: {mv: ...}`` is read. A scalar
    # ``galactic`` value silently falls back to 21.0. Confirm whether scalar
    # values should be accepted, and in which units.
    background_mv = _get(bg_cfg, "galactic", "mv", default=21.0)
    if background_mv is None:
        background_mv = 21.0

    # Bias
    bias = _get(fpa, "bias", default=50.0)

    sensor = SensorConfig(
        height=height,
        width=width,
        y_fov=y_fov,
        x_fov=x_fov,
        exposure=exposure,
        gap=gap,
        num_frames=num_frames,
        zeropoint=zeropoint,
        psf_sigma=psf_sigma,
        dark_current=dark_current,
        read_noise=read_noise,
        electronic_noise=electronic_noise,
        background_mv=background_mv,
        bias=bias,
        gain=gain,
        fwc=fwc,
        a2d_bias=a2d_bias,
        a2d_dtype=a2d_dtype,
    )

    # Stars
    geom = _section(cfg, "geometry")
    stars_cfg = _section(geom, "stars")

    star_mode = _get(stars_cfg, "mode", default="bins")
    if star_mode == "random" or "mv" in stars_cfg:
        # Random bins mode
        mv_bins = _get(stars_cfg, "mv", "bins", default=[6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
        density = _get(stars_cfg, "mv", "density", default=[1.0] * (len(mv_bins) - 1))
        stars = StarFieldConfig(mode="bins", mv_bins=mv_bins, density=density)
    elif star_mode == "sstr7":
        stars = StarFieldConfig(
            mode="sstr7",
            catalog_path=_get(stars_cfg, "path", default=None),
        )
    else:
        stars = StarFieldConfig(mode="bins")

    # Star motion: prefer ``geometry.star_motion``. Otherwise fall back to
    # the nested ``geometry.stars.motion`` layout used by satsim configs.
    star_motion_cfg = _section(geom, "star_motion") or _section(stars_cfg, "motion")

    rotation = _get(star_motion_cfg, "rotation", default=0.0)
    translation = _get(star_motion_cfg, "translation", default=[0.0, 0.0])
    t_osf = _get(cfg, "t_osf", default=_get(fpa, "t_osf", default=100))

    star_motion = StarMotionConfig(
        rotation=rotation,
        translation=translation,
        temporal_osf=t_osf,
    )

    # Targets (observation objects). Accepted forms:
    # - a list of target dicts;
    # - a single target dict;
    # - the {"mode": "list", "list": [...]} wrapper.
    # An empty/None ``obs`` means no targets.
    obs_cfg = _get(geom, "obs", default=[])
    if isinstance(obs_cfg, dict) and isinstance(obs_cfg.get("list"), list):
        obs_list = obs_cfg["list"]
    elif isinstance(obs_cfg, list):
        obs_list = obs_cfg
    elif obs_cfg is None:
        obs_list = []
    else:
        obs_list = [obs_cfg]

    targets = []
    for obs in obs_list:
        if not isinstance(obs, dict):
            # Skip null/scalar placeholders instead of inventing a default target.
            continue
        origin = _get(obs, "origin", default=[0.5, 0.5])
        velocity = _get(obs, "velocity", default=[0.0, 0.0])
        mv = _get(obs, "mv", default=12.0)
        mode = _get(obs, "mode", default="line")
        targets.append(TargetConfig(mode=mode, origin=origin, velocity=velocity, mv=mv))

    return SceneConfig(
        sensor=sensor,
        stars=stars,
        star_motion=star_motion,
        targets=targets,
        seed=seed,
        device="auto",
        enable_shot_noise=_get(noise_cfg, "photon", default=True),
        enable_read_noise=read_noise > 0 or electronic_noise > 0,
    )
sdasim/_version.py ADDED
@@ -0,0 +1,6 @@
1
# Installed distribution version. The lookup is deliberately best-effort:
# any failure (no importlib.metadata, package not installed, broken
# metadata) keeps the dev fallback, so it can never break package import.
__version__ = "0.1.0.dev0"
try:
    from importlib.metadata import version

    __version__ = version("sdasim")
except Exception:
    pass  # keep the fallback assigned above
sdasim/batch.py ADDED
@@ -0,0 +1,346 @@
1
+ """Batched multi-scene rendering.
2
+
3
+ Fuses N heterogeneous Scenes into a single kernel launch. Each scene can have
4
+ its own star catalog, its own PSF sigma, its own noise/gain/background/etc.
5
+
6
+ Restrictions in this first pass (matches what zerosda's pretraining loop
7
+ needs):
8
+ - All scenes must share (height, width) and a2d_dtype.
9
+ - Only supports frame_idx=0 and mode=None. Rate_sidereal mode-dispatch
10
+ is left to the non-batched Scene.render() path.
11
+ - Star motion (rate tracking) is supported per-scene via expand_motion.
12
+ - Targets are rendered with their own velocities, also via expand_motion.
13
+
14
+ Key win: all stars and all targets across all B scenes are splatted with two
15
+ kernel launches total (one for stars, one for targets), instead of 2*B.
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import math
21
+ from dataclasses import dataclass
22
+ from typing import Sequence
23
+
24
+ import torch
25
+ from torch import Tensor
26
+
27
+ from sdasim.fpa import MAX_PIXEL_VALUE
28
+ from sdasim.noise import poisson_noise
29
+ from sdasim.render import expand_motion
30
+ from sdasim.scene import Scene
31
+ from sdasim.splat import splat_gaussians_batched
32
+ from sdasim.targets import compute_target_positions
33
+
34
# Attribute name under which each Scene's batched-render cache is stored.
_CACHE_ATTR = "_batch_source_cache"


def _ensure_source_cache(scene: Scene) -> dict:
    """Return the scene's batched-render source cache, building it on first call.

    The cache holds:
    - the motion-expanded star and target source tensors with their counts;
    - the scalar sensor, noise and A2D parameters read by
      render_scene_batch.

    It is attached to the scene under ``_CACHE_ATTR`` and never invalidated,
    so it assumes the scene's config does not change after the first call.
    """
    cached = getattr(scene, _CACHE_ATTR, None)
    if cached is not None:
        return cached

    sensor = scene.sensor
    motion = scene.config.star_motion
    device = scene.device
    frame_center = (sensor.height / 2.0, sensor.width / 2.0)

    # Stars: blur along the shared star motion over the exposure. This only
    # happens when the motion is non-zero, temporal oversampling is above 1,
    # and at least one star exists.
    star_pos = scene.star_positions
    star_int = scene.star_intensities
    moving = motion.rotation != 0.0 or motion.translation[0] != 0.0 or motion.translation[1] != 0.0
    if moving and motion.temporal_osf > 1 and star_pos.shape[0] > 0:
        star_pos, star_int = expand_motion(
            star_pos,
            star_int,
            motion.translation,
            motion.rotation,
            0.0,
            sensor.exposure,
            motion.temporal_osf,
            frame_center,
        )

    # Targets: each is expanded along its own velocity. The sample count is
    # 2 per pixel of displacement of the largest absolute velocity component.
    tgt_pos, tgt_int, tgt_vel = compute_target_positions(scene.config.targets, sensor, 0, device)
    if tgt_pos.shape[0] == 0:
        tgt_pos = torch.zeros(0, 2, device=device, dtype=torch.float32)
        tgt_int = torch.zeros(0, device=device, dtype=torch.float32)
    else:
        fastest = float(tgt_vel.abs().max().item()) if tgt_vel.numel() else 0.0
        n_samples = max(1, int(fastest * sensor.exposure * 2))
        if n_samples > 1:
            tgt_pos, tgt_int = expand_motion(
                tgt_pos,
                tgt_int,
                tgt_vel,
                0.0,
                0.0,
                sensor.exposure,
                n_samples,
                frame_center,
            )

    cache = dict(
        star_pos=star_pos,
        star_int=star_int,
        star_count=int(star_pos.shape[0]),
        tgt_pos=tgt_pos,
        tgt_int=tgt_int,
        tgt_count=int(tgt_pos.shape[0]),
        psf_sigma=float(sensor.psf_sigma),
        background_pe=float(scene.background_pe),
        dark_current_pe=float(scene.dark_current_pe),
        bias_pe=float(scene.bias_pe),
        read_noise=float(sensor.read_noise),
        electronic_noise=float(sensor.electronic_noise),
        gain=float(sensor.gain),
        fwc=float(sensor.fwc),
        a2d_bias=float(sensor.a2d_bias),
    )
    setattr(scene, _CACHE_ATTR, cache)
    return cache
114
+
115
+
116
@dataclass
class BatchRenderResult:
    """Output of render_scene_batch.

    Attributes:
        digital: [B, H, W] digital image in ADU.
            - Fused Gaussian path: a floating-point tensor of floored values,
              clamped to the ``a2d_dtype`` range. It is not cast to
              ``a2d_dtype``; cast it yourself if an integer dtype is needed.
            - Empirical path: whatever ``Scene.render_signals`` returns,
              stacked over scenes.
        star_signal: [B, H, W] pre-noise star-only PE.
        target_signal: [B, H, W] pre-noise target-only PE.
        star_positions_per_frame: length-B list of [N_b, 2] star positions
            before motion expansion (each scene's ``Scene.star_positions``).
            Useful for generating training heatmaps without running the
            splat again.
        num_stars: [B] long tensor holding ``N_b`` for each frame
            (pre-motion-expand).
    """

    digital: Tensor
    star_signal: Tensor
    target_signal: Tensor
    star_positions_per_frame: list[Tensor]
    num_stars: Tensor
135
+
136
+
137
def _collect_cached(
    scenes: Sequence[Scene],
    device: torch.device,
    kind: str,
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
    """Concatenate the cached ``kind`` sources ("star" or "tgt") of all scenes.

    Returns ``(positions, intensities, frame_ids, per_source_sigma)``:
    - ``frame_ids[k]`` is the index in ``scenes`` that source ``k`` came from.
    - ``per_source_sigma[k]`` is that scene's cached PSF sigma.

    Scenes with zero sources of this kind contribute nothing. When no scene
    has any, empty tensors are returned with shapes ``[0, 2]`` (float32),
    ``[0]`` (float32), ``[0]`` (long) and ``[0]`` (float32). The motion
    expansion already happened when each cache was filled, so this step is
    just concatenation plus repeat_interleave.
    """
    pos_key, int_key, cnt_key = f"{kind}_pos", f"{kind}_int", f"{kind}_count"

    # (frame index, cache) for every scene that has at least one source.
    contributing = []
    for frame_idx, scene in enumerate(scenes):
        entry = _ensure_source_cache(scene)
        if entry[cnt_key] != 0:
            contributing.append((frame_idx, entry))

    if not contributing:
        empty_2d = torch.zeros(0, 2, dtype=torch.float32, device=device)
        empty_1d = torch.zeros(0, dtype=torch.float32, device=device)
        empty_ids = torch.zeros(0, dtype=torch.long, device=device)
        return empty_2d, empty_1d, empty_ids, empty_1d

    positions = torch.cat([entry[pos_key] for _, entry in contributing], dim=0)
    intensities = torch.cat([entry[int_key] for _, entry in contributing], dim=0)
    lengths = torch.tensor(
        [entry[cnt_key] for _, entry in contributing], dtype=torch.long, device=device
    )
    frame_ids = torch.repeat_interleave(
        torch.tensor([idx for idx, _ in contributing], dtype=torch.long, device=device),
        lengths,
    )
    per_source_sigma = torch.repeat_interleave(
        torch.tensor(
            [entry["psf_sigma"] for _, entry in contributing],
            dtype=torch.float32,
            device=device,
        ),
        lengths,
    )
    return positions, intensities, frame_ids, per_source_sigma
186
+
187
+
188
def _render_batch_empirical(scenes: Sequence[Scene]) -> BatchRenderResult:
    """Render a batch one scene at a time through ``Scene.render_signals``.

    This path is used when a batch contains empirical-PSF or empirical-noise
    scenes. Those scenes carry their own sampled kernel and do a per-scene
    FFT, so they cannot share the fused Gaussian splat.

    Each scene renders frame 0 independently. Its digital, star and target
    images are then stacked along a new leading batch dimension.
    """
    device = scenes[0].device
    digital, star_signals, target_signals, positions = [], [], [], []
    for scene in scenes:
        frame_digital, frame_stars, frame_targets, _ = scene.render_signals(0)
        digital.append(frame_digital)
        star_signals.append(frame_stars)
        target_signals.append(frame_targets)
        positions.append(scene.star_positions)
    counts = torch.tensor([p.shape[0] for p in positions], dtype=torch.long, device=device)
    return BatchRenderResult(
        digital=torch.stack(digital),
        star_signal=torch.stack(star_signals),
        target_signal=torch.stack(target_signals),
        star_positions_per_frame=positions,
        num_stars=counts,
    )
202
+
203
+
204
def render_scene_batch(scenes: Sequence[Scene]) -> BatchRenderResult:
    """Render N scenes in a single fused pass.

    Each scene is treated as its own telescope/exposure with its own PSF,
    background, read noise, gain, etc. All scenes contribute to one
    (B, H, W) output via per-source frame_id tagging on the splats.

    Shot noise and read noise follow each scene's own
    ``config.enable_shot_noise`` / ``config.enable_read_noise`` flag. A
    noiseless scene therefore does not switch noise off for the rest of the
    batch. If any scene uses an empirical PSF or noise model, the whole batch
    is rendered scene by scene instead of through the fused splat.

    Args:
        scenes: Sequence of Scene objects. All must share (height, width,
            a2d_dtype, device). Scene.config.mode must be None for each scene.

    Returns:
        BatchRenderResult with digital, star_signal, target_signal, and the
        per-frame pre-expansion star positions (handy for building heatmaps).

    Raises:
        ValueError: If ``scenes`` is empty, if the scenes do not share
            (height, width) or a2d_dtype, or if any scene has a non-None mode.
    """
    if len(scenes) == 0:
        raise ValueError("render_scene_batch requires at least one scene")

    ref = scenes[0]
    device = ref.device
    H = ref.sensor.height
    W = ref.sensor.width
    a2d_dtype = ref.sensor.a2d_dtype
    B = len(scenes)

    for s in scenes:
        if s.sensor.height != H or s.sensor.width != W:
            raise ValueError(
                "render_scene_batch requires all scenes to share (height, width); "
                f"got {(s.sensor.height, s.sensor.width)} vs {(H, W)}"
            )
        if s.sensor.a2d_dtype != a2d_dtype:
            raise ValueError("render_scene_batch requires all scenes to share a2d_dtype")
        if s.config.mode is not None:
            raise ValueError(
                "render_scene_batch only supports mode=None (sidereal/rate_track); "
                f"got mode={s.config.mode}"
            )

    # Empirical scenes use per-scene sampled kernels + FFT -> render per-scene and stack.
    if any(
        s.sensor.psf_model == "empirical" or s.sensor.noise_model == "empirical" for s in scenes
    ):
        return _render_batch_empirical(scenes)

    # Ensure caches exist (lazy, idempotent on pool scenes).
    caches = [_ensure_source_cache(s) for s in scenes]

    per_frame_positions = [s.star_positions for s in scenes]
    star_pos, star_int, star_fid, star_sig = _collect_cached(scenes, device, "star")
    tgt_pos, tgt_int, tgt_fid, tgt_sig = _collect_cached(scenes, device, "tgt")

    # One fused splat per source kind. splat_gaussians_batched handles the
    # empty-source case by returning zeros of the right shape.
    star_signal = splat_gaussians_batched(B, H, W, star_pos, star_int, star_fid, star_sig)
    target_signal = splat_gaussians_batched(B, H, W, tgt_pos, tgt_int, tgt_fid, tgt_sig)

    # Per-frame scalars go to the device in a single copy. Each one is then
    # viewed as [B, 1, 1] so it broadcasts over (H, W).
    frame_params = torch.tensor(
        [
            [
                c["background_pe"],
                c["dark_current_pe"],
                c["bias_pe"],
                c["a2d_bias"],
                c["fwc"],
                c["gain"],
            ]
            for c in caches
        ],
        dtype=torch.float32,
        device=device,
    ).view(B, 6, 1, 1)
    bg, dc, bias, a2d_bias_t, fwc_t, gain_t = frame_params.unbind(dim=1)

    signal = star_signal + target_signal + bg + dc + bias

    # Noise is drawn for the whole batch and then masked per frame. When every
    # scene has a flag enabled (or none does), the random draws are exactly
    # those of an unmasked batch.
    shot_on = [bool(s.config.enable_shot_noise) for s in scenes]
    read_on = [bool(s.config.enable_read_noise) for s in scenes]
    if all(shot_on):
        signal = poisson_noise(signal)
    elif any(shot_on):
        shot_mask = torch.tensor(shot_on, dtype=torch.bool, device=device).view(B, 1, 1)
        signal = torch.where(shot_mask, poisson_noise(signal), signal)
    if any(read_on):
        # A sigma of 0 leaves frames with read noise disabled unchanged.
        rn_sigma = torch.tensor(
            [
                math.sqrt(c["read_noise"] ** 2 + c["electronic_noise"] ** 2) if on else 0.0
                for c, on in zip(caches, read_on)
            ],
            dtype=torch.float32,
            device=device,
        ).view(B, 1, 1)
        signal = signal + rn_sigma * torch.randn_like(signal)

    # A2D conversion:
    # 1. Add the per-frame A2D bias and clamp to [0, fwc].
    # 2. Quantize by gain.
    # 3. Clamp to the dtype's maximum. The result keeps the floating dtype.
    biased = (signal + a2d_bias_t).clamp(min=0.0)
    biased = torch.minimum(biased, fwc_t)
    dn = torch.floor(biased / gain_t)
    max_val = MAX_PIXEL_VALUE.get(a2d_dtype, 65535.0)
    dn = dn.clamp(min=0.0, max=max_val)

    num_stars = torch.tensor(
        [p.shape[0] for p in per_frame_positions],
        dtype=torch.long,
        device=device,
    )

    return BatchRenderResult(
        digital=dn,
        star_signal=star_signal,
        target_signal=target_signal,
        star_positions_per_frame=per_frame_positions,
        num_stars=num_stars,
    )