torchrir 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
torchrir/core.py ADDED
@@ -0,0 +1,741 @@
1
+ from __future__ import annotations
2
+
3
+ """Core RIR simulation functions (static and dynamic)."""
4
+
5
+ import math
6
+ from typing import Optional, Tuple
7
+
8
+ import torch
9
+ from torch import Tensor
10
+
11
+ from .config import SimulationConfig, default_config
12
+ from .directivity import directivity_gain, split_directivity
13
+ from .room import MicrophoneArray, Room, Source
14
+ from .utils import (
15
+ as_tensor,
16
+ ensure_dim,
17
+ estimate_beta_from_t60,
18
+ estimate_t60_from_beta,
19
+ infer_device_dtype,
20
+ normalize_orientation,
21
+ orientation_to_unit,
22
+ resolve_device,
23
+ )
24
+
25
+
26
def simulate_rir(
    *,
    room: Room,
    sources: Source | Tensor,
    mics: MicrophoneArray | Tensor,
    max_order: int | None,
    nb_img: Optional[Tensor | Tuple[int, ...]] = None,
    nsample: Optional[int] = None,
    tmax: Optional[float] = None,
    tdiff: Optional[float] = None,
    directivity: str | tuple[str, str] | None = "omni",
    orientation: Optional[Tensor | tuple[Tensor, Tensor]] = None,
    config: Optional[SimulationConfig] = None,
    device: Optional[torch.device | str] = None,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    """Simulate a static RIR using the image source method.

    Args:
        room: Room configuration (geometry, fs, reflection coefficients).
        sources: Source positions or a Source object.
        mics: Microphone positions or a MicrophoneArray object.
        max_order: Maximum reflection order (uses config if None).
        nb_img: Optional per-dimension image counts (overrides max_order).
        nsample: Output length in samples.
        tmax: Output length in seconds (used if nsample is None).
        tdiff: Optional time to start diffuse tail modeling.
        directivity: Directivity pattern(s) for source and mic (uses config if None).
        orientation: Orientation vectors or angles.
        config: Optional simulation configuration overrides.
        device: Output device.
        dtype: Output dtype.

    Returns:
        Tensor of shape (n_src, n_mic, nsample).

    Raises:
        TypeError: If ``room`` is not a Room instance.
        ValueError: If no output length can be resolved, or if positions,
            orientations, or reflection coefficients have inconsistent shapes.
    """
    # Explicit arguments always win; cfg only fills in what the caller omitted.
    cfg = config or default_config()
    cfg.validate()

    if device is None and cfg.device is not None:
        device = cfg.device

    if max_order is None:
        if cfg.max_order is None:
            raise ValueError("max_order must be provided if not set in config")
        max_order = cfg.max_order

    # cfg.tmax applies only when the caller supplied no output length at all.
    if tmax is None and nsample is None and cfg.tmax is not None:
        tmax = cfg.tmax

    if directivity is None:
        directivity = cfg.directivity or "omni"

    if not isinstance(room, Room):
        raise TypeError("room must be a Room instance")
    if nsample is None and tmax is None:
        raise ValueError("nsample or tmax must be provided")
    if nsample is None:
        nsample = int(math.ceil(tmax * room.fs))
    if nsample <= 0:
        raise ValueError("nsample must be positive")
    if max_order < 0:
        raise ValueError("max_order must be non-negative")

    if isinstance(device, str):
        device = resolve_device(device)

    # Unpack Source/MicrophoneArray objects (or raw tensors) into positions and
    # optional orientations; the shared `orientation` argument is forwarded to
    # both calls and split per-entity inside _prepare_entities.
    src_pos, src_ori = _prepare_entities(
        sources, orientation, which="source", device=device, dtype=dtype
    )
    mic_pos, mic_ori = _prepare_entities(
        mics, orientation, which="mic", device=device, dtype=dtype
    )

    # Settle on a single device/dtype before any arithmetic.
    device, dtype = infer_device_dtype(
        src_pos, mic_pos, room.size, device=device, dtype=dtype
    )
    src_pos = as_tensor(src_pos, device=device, dtype=dtype)
    mic_pos = as_tensor(mic_pos, device=device, dtype=dtype)

    if src_ori is not None:
        src_ori = as_tensor(src_ori, device=device, dtype=dtype)
    if mic_ori is not None:
        mic_ori = as_tensor(mic_ori, device=device, dtype=dtype)

    room_size = as_tensor(room.size, device=device, dtype=dtype)
    room_size = ensure_dim(room_size)
    dim = room_size.numel()

    # Allow a single position passed as a 1-D tensor.
    if src_pos.ndim == 1:
        src_pos = src_pos.unsqueeze(0)
    if mic_pos.ndim == 1:
        mic_pos = mic_pos.unsqueeze(0)
    if src_pos.ndim != 2 or src_pos.shape[1] != dim:
        raise ValueError("sources must be of shape (n_src, dim)")
    if mic_pos.ndim != 2 or mic_pos.shape[1] != dim:
        raise ValueError("mics must be of shape (n_mic, dim)")

    # Wall reflection coefficients: explicit beta, else derived from t60.
    beta = _resolve_beta(room, room_size, device=device, dtype=dtype)
    beta = _validate_beta(beta, dim)

    # Image-source lattice indices and per-image reflection gains.
    n_vec = _image_source_indices(max_order, dim, device=device, nb_img=nb_img)
    refl = _reflection_coefficients(n_vec, beta)

    src_pattern, mic_pattern = split_directivity(directivity)
    mic_dir = None
    if mic_pattern != "omni":
        if mic_ori is None:
            raise ValueError("mic orientation required for non-omni directivity")
        mic_dir = orientation_to_unit(mic_ori, dim)

    n_src = src_pos.shape[0]
    n_mic = mic_pos.shape[0]
    rir = torch.zeros((n_src, n_mic, nsample), device=device, dtype=dtype)
    fdl = cfg.frac_delay_length
    fdl2 = (fdl - 1) // 2  # half-length of the fractional-delay filter
    img_chunk = cfg.image_chunk_size
    if img_chunk <= 0:
        # Non-positive chunk size means "process every image in one batch".
        img_chunk = n_vec.shape[0]

    src_dirs = None
    if src_pattern != "omni":
        if src_ori is None:
            raise ValueError("source orientation required for non-omni directivity")
        src_dirs = orientation_to_unit(src_ori, dim)
        if src_dirs.ndim == 1:
            # One shared orientation: broadcast it to every source.
            src_dirs = src_dirs.unsqueeze(0).repeat(n_src, 1)
        if src_dirs.ndim != 2 or src_dirs.shape[0] != n_src:
            raise ValueError("source orientation must match number of sources")

    # Walk the image lattice in chunks to bound peak memory usage.
    for start in range(0, n_vec.shape[0], img_chunk):
        end = min(start + img_chunk, n_vec.shape[0])
        n_vec_chunk = n_vec[start:end]
        refl_chunk = refl[start:end]
        sample_chunk, attenuation_chunk = _compute_image_contributions_batch(
            src_pos,
            mic_pos,
            room_size,
            n_vec_chunk,
            refl_chunk,
            room,
            fdl2,
            src_pattern=src_pattern,
            mic_pattern=mic_pattern,
            src_dirs=src_dirs,
            mic_dir=mic_dir,
        )
        _accumulate_rir_batch(rir, sample_chunk, attenuation_chunk, cfg)

    # The diffuse tail is only modeled when both tdiff and tmax are given —
    # an nsample-only call skips it even if tdiff was provided.
    if tdiff is not None and tmax is not None and tdiff < tmax:
        rir = _apply_diffuse_tail(rir, room, beta, tdiff, tmax, seed=cfg.seed)
    return rir
178
+
179
+
180
def simulate_dynamic_rir(
    *,
    room: Room,
    src_traj: Tensor,
    mic_traj: Tensor,
    max_order: int | None,
    nsample: Optional[int] = None,
    tmax: Optional[float] = None,
    directivity: str | tuple[str, str] | None = "omni",
    orientation: Optional[Tensor | tuple[Tensor, Tensor]] = None,
    config: Optional[SimulationConfig] = None,
    device: Optional[torch.device | str] = None,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    """Simulate time-varying RIRs along source/microphone trajectories.

    Each time step is rendered independently with :func:`simulate_rir` and
    the per-step results are stacked along a leading time axis.

    Args:
        room: Room configuration.
        src_traj: Source trajectory (T, n_src, dim) or (T, dim).
        mic_traj: Microphone trajectory (T, n_mic, dim) or (T, dim).
        max_order: Maximum reflection order (uses config if None).
        nsample: Output length in samples.
        tmax: Output length in seconds (used if nsample is None).
        directivity: Directivity pattern(s) for source and mic (uses config if None).
        orientation: Orientation vectors or angles, shared across steps.
        config: Optional simulation configuration overrides.
        device: Output device.
        dtype: Output dtype.

    Returns:
        Tensor of shape (T, n_src, n_mic, nsample).

    Raises:
        ValueError: On missing max_order or mismatched trajectory shapes.
    """
    cfg = config or default_config()
    cfg.validate()

    # Fill unspecified arguments from the configuration.
    if device is None and cfg.device is not None:
        device = cfg.device
    if max_order is None:
        if cfg.max_order is None:
            raise ValueError("max_order must be provided if not set in config")
        max_order = cfg.max_order
    if tmax is None and nsample is None and cfg.tmax is not None:
        tmax = cfg.tmax
    if directivity is None:
        directivity = cfg.directivity or "omni"
    if isinstance(device, str):
        device = resolve_device(device)

    src_traj = as_tensor(src_traj, device=device, dtype=dtype)
    mic_traj = as_tensor(mic_traj, device=device, dtype=dtype)

    # Promote (T, dim) trajectories to a singleton entity axis.
    if src_traj.ndim == 2:
        src_traj = src_traj.unsqueeze(1)
    if mic_traj.ndim == 2:
        mic_traj = mic_traj.unsqueeze(1)
    if src_traj.ndim != 3:
        raise ValueError("src_traj must be of shape (T, n_src, dim)")
    if mic_traj.ndim != 3:
        raise ValueError("mic_traj must be of shape (T, n_mic, dim)")
    if src_traj.shape[0] != mic_traj.shape[0]:
        raise ValueError("src_traj and mic_traj must have the same time length")

    # One static render per trajectory snapshot.
    snapshots = [
        simulate_rir(
            room=room,
            sources=src_traj[step],
            mics=mic_traj[step],
            max_order=max_order,
            nsample=nsample,
            tmax=tmax,
            directivity=directivity,
            orientation=orientation,
            config=config,
            device=device,
            dtype=dtype,
        )
        for step in range(src_traj.shape[0])
    ]
    return torch.stack(snapshots, dim=0)
264
+
265
+
266
def _prepare_entities(
    entities: Source | MicrophoneArray | Tensor,
    orientation: Optional[Tensor | tuple[Tensor, Tensor]],
    *,
    which: str,
    device: Optional[torch.device | str],
    dtype: Optional[torch.dtype],
) -> Tuple[Tensor, Optional[Tensor]]:
    """Extract positions and orientations from entities or raw tensors.

    An explicit ``orientation`` argument overrides any orientation carried by
    a Source/MicrophoneArray object; a 2-tuple is split as (source, mic)
    according to ``which``.

    Returns:
        Tuple of (positions, orientation), orientation possibly None.
    """
    if isinstance(entities, (Source, MicrophoneArray)):
        pos, ori = entities.positions, entities.orientation
    else:
        pos, ori = entities, None

    if orientation is not None:
        if not isinstance(orientation, (list, tuple)):
            ori = orientation
        elif len(orientation) == 2:
            ori = orientation[0] if which == "source" else orientation[1]
        else:
            raise ValueError("orientation tuple must have length 2")

    pos = as_tensor(pos, device=device, dtype=dtype)
    ori = as_tensor(ori, device=device, dtype=dtype) if ori is not None else None
    return pos, ori
296
+
297
+
298
+ def _resolve_beta(
299
+ room: Room, room_size: Tensor, *, device: torch.device, dtype: torch.dtype
300
+ ) -> Tensor:
301
+ """Resolve reflection coefficients from beta/t60/defaults.
302
+
303
+ Returns:
304
+ Tensor of reflection coefficients per wall.
305
+ """
306
+ if room.beta is not None:
307
+ return as_tensor(room.beta, device=device, dtype=dtype)
308
+ if room.t60 is not None:
309
+ return estimate_beta_from_t60(room_size, room.t60, device=device, dtype=dtype)
310
+ dim = room_size.numel()
311
+ default_faces = 4 if dim == 2 else 6
312
+ return torch.ones((default_faces,), device=device, dtype=dtype)
313
+
314
+
315
+ def _validate_beta(beta: Tensor, dim: int) -> Tensor:
316
+ """Validate beta size against room dimension.
317
+
318
+ Returns:
319
+ The validated beta tensor.
320
+ """
321
+ expected = 4 if dim == 2 else 6
322
+ if beta.numel() != expected:
323
+ raise ValueError(f"beta must have {expected} elements for {dim}D")
324
+ return beta
325
+
326
+
327
+ def _image_source_indices(
328
+ max_order: int,
329
+ dim: int,
330
+ *,
331
+ device: torch.device,
332
+ nb_img: Optional[Tensor | Tuple[int, ...]] = None,
333
+ ) -> Tensor:
334
+ """Generate image source index vectors up to the given order.
335
+
336
+ Returns:
337
+ Tensor of shape (n_images, dim).
338
+ """
339
+ if nb_img is not None:
340
+ nb = as_tensor(nb_img, device=device, dtype=torch.int64)
341
+ if nb.numel() != dim:
342
+ raise ValueError("nb_img must match room dimension")
343
+ ranges = [torch.arange(-n, n + 1, device=device, dtype=torch.int64) for n in nb]
344
+ grids = torch.meshgrid(*ranges, indexing="ij")
345
+ return torch.stack([g.reshape(-1) for g in grids], dim=-1)
346
+ rng = torch.arange(-max_order, max_order + 1, device=device, dtype=torch.int64)
347
+ grids = torch.meshgrid(*([rng] * dim), indexing="ij")
348
+ n_vec = torch.stack([g.reshape(-1) for g in grids], dim=-1)
349
+ order = torch.sum(torch.abs(n_vec), dim=-1)
350
+ return n_vec[order <= max_order]
351
+
352
+
353
+ def _image_positions(src: Tensor, room_size: Tensor, n_vec: Tensor) -> Tensor:
354
+ """Compute image source positions for a given source.
355
+
356
+ Returns:
357
+ Tensor of image positions (n_images, dim).
358
+ """
359
+ n_vec_f = n_vec.to(dtype=src.dtype)
360
+ sign = torch.where((n_vec % 2) == 0, 1.0, -1.0).to(dtype=src.dtype)
361
+ n = torch.floor_divide(n_vec + 1, 2).to(dtype=src.dtype)
362
+ return 2.0 * room_size * n + sign * src
363
+
364
+
365
+ def _image_positions_batch(src_pos: Tensor, room_size: Tensor, n_vec: Tensor) -> Tensor:
366
+ """Compute image source positions for multiple sources."""
367
+ sign = torch.where((n_vec % 2) == 0, 1.0, -1.0).to(dtype=src_pos.dtype)
368
+ n = torch.floor_divide(n_vec + 1, 2).to(dtype=src_pos.dtype)
369
+ base = 2.0 * room_size * n
370
+ return base[None, :, :] + sign[None, :, :] * src_pos[:, None, :]
371
+
372
+
373
+ def _reflection_coefficients(n_vec: Tensor, beta: Tensor) -> Tensor:
374
+ """Compute reflection coefficients for each image source.
375
+
376
+ Returns:
377
+ Tensor of shape (n_images,) with per-image gains.
378
+ """
379
+ dim = n_vec.shape[1]
380
+ beta = beta.view(dim, 2)
381
+ beta_lo = beta[:, 0]
382
+ beta_hi = beta[:, 1]
383
+
384
+ n = n_vec
385
+ k = torch.abs(n)
386
+ n_hi = torch.where(n >= 0, (n + 1) // 2, k // 2)
387
+ n_lo = torch.where(n >= 0, n // 2, (k + 1) // 2)
388
+
389
+ n_hi = n_hi.to(dtype=beta.dtype)
390
+ n_lo = n_lo.to(dtype=beta.dtype)
391
+
392
+ coeff = (beta_hi**n_hi) * (beta_lo**n_lo)
393
+ return torch.prod(coeff, dim=1)
394
+
395
+
396
def _compute_image_contributions(
    src: Tensor,
    mic_pos: Tensor,
    room_size: Tensor,
    n_vec: Tensor,
    refl: Tensor,
    room: Room,
    fdl2: int,
    *,
    src_pattern: str,
    mic_pattern: str,
    src_dir: Optional[Tensor],
    mic_dir: Optional[Tensor],
) -> Tuple[Tensor, Tensor]:
    """Compute sample positions and attenuation for one source and all mics.

    Returns:
        Tuple (sample, attenuation) of shape (n_mic, n_images): fractional
        sample positions of each arrival and the corresponding gains.
    """
    images = _image_positions(src, room_size, n_vec)
    # Displacement image -> mic, shape (n_mic, n_images, dim).
    disp = mic_pos[:, None, :] - images[None, :, :]
    ranges = torch.linalg.norm(disp, dim=-1)
    # Keep 1/r finite even if a mic coincides with an image.
    ranges = torch.clamp(ranges, min=1e-6)
    delay = ranges / room.c
    # Shift by half the fractional-delay filter length (in seconds).
    delay = delay + (fdl2 / room.fs)
    samples = delay * room.fs

    weights = refl[None, :]
    if src_pattern != "omni":
        if src_dir is None:
            raise ValueError("source orientation required for non-omni directivity")
        weights = weights * directivity_gain(src_pattern, _cos_between(disp, src_dir))
    if mic_pattern != "omni":
        if mic_dir is None:
            raise ValueError("mic orientation required for non-omni directivity")
        weights = weights * directivity_gain(mic_pattern, _cos_between(-disp, mic_dir))

    # 1/r spherical spreading on top of reflection/directivity gains.
    return samples, weights / ranges
433
+
434
+
435
def _compute_image_contributions_batch(
    src_pos: Tensor,
    mic_pos: Tensor,
    room_size: Tensor,
    n_vec: Tensor,
    refl: Tensor,
    room: Room,
    fdl2: int,
    *,
    src_pattern: str,
    mic_pattern: str,
    src_dirs: Optional[Tensor],
    mic_dir: Optional[Tensor],
) -> Tuple[Tensor, Tensor]:
    """Compute samples/attenuation for all sources/mics/images in batch.

    Args:
        src_pos: Source positions (n_src, dim).
        mic_pos: Microphone positions (n_mic, dim).
        room_size: Room dimensions (dim,).
        n_vec: Image lattice indices (n_img, dim).
        refl: Per-image reflection gains (n_img,).
        room: Room providing ``fs`` and speed of sound ``c``.
        fdl2: Half-length of the fractional-delay filter, in taps.
        src_pattern: Source directivity name; "omni" skips the gain.
        mic_pattern: Mic directivity name; "omni" skips the gain.
        src_dirs: Per-source unit orientations (n_src, dim); required when
            ``src_pattern`` is not omni.
        mic_dir: Mic orientation(s); required when ``mic_pattern`` is not omni.

    Returns:
        Tuple (sample, attenuation), each broadcastable to
        (n_src, n_mic, n_img): fractional sample positions and path gains.
    """
    img = _image_positions_batch(src_pos, room_size, n_vec)
    # Displacement image -> mic, shape (n_src, n_mic, n_img, dim).
    vec = mic_pos[None, :, None, :] - img[:, None, :, :]
    dist = torch.linalg.norm(vec, dim=-1)
    # Keep 1/dist finite when a mic coincides with an image.
    dist = torch.clamp(dist, min=1e-6)
    time = dist / room.c
    # Offset by fdl2 samples so the centered filter taps used in
    # _accumulate_rir_batch_impl (offsets -fdl2..+fdl2) stay in range.
    time = time + (fdl2 / room.fs)
    sample = time * room.fs

    gain = refl.view(1, 1, -1)
    if src_pattern != "omni":
        if src_dirs is None:
            raise ValueError("source orientation required for non-omni directivity")
        # Add mic/image broadcast axes to the per-source orientations.
        src_dirs = src_dirs[:, None, None, :]
        cos_theta = _cos_between(vec, src_dirs)
        gain = gain * directivity_gain(src_pattern, cos_theta)
    if mic_pattern != "omni":
        if mic_dir is None:
            raise ValueError("mic orientation required for non-omni directivity")
        # Per-mic orientations (n_mic, dim) get a broadcast mic axis; a single
        # shared vector is reshaped to broadcast over everything.
        mic_dir = mic_dir[None, :, None, :] if mic_dir.ndim == 2 else mic_dir.view(1, 1, 1, -1)
        cos_theta = _cos_between(-vec, mic_dir)
        gain = gain * directivity_gain(mic_pattern, cos_theta)

    # 1/r spherical spreading on top of reflection/directivity gains.
    attenuation = gain / dist
    return sample, attenuation
474
+
475
+
476
def _select_orientation(orientation: Tensor, idx: int, count: int, dim: int) -> Tensor:
    """Pick the correct orientation vector for a given entity index."""
    # A scalar angle or single vector is shared by every entity; the original
    # handled ndim 0 and 1 in two identical branches, merged here.
    if orientation.ndim <= 1:
        return orientation_to_unit(orientation, dim)
    if orientation.ndim == 2 and orientation.shape[0] == count:
        return orientation_to_unit(orientation[idx], dim)
    raise ValueError("orientation must be shape (dim,), (count, dim), or angles")
485
+
486
+
487
def _cos_between(vec: Tensor, orientation: Tensor) -> Tensor:
    """Cosine of the angle between each vector and the orientation axis."""
    axis = normalize_orientation(orientation)
    lengths = torch.linalg.norm(vec, dim=-1, keepdim=True)
    # Dot product of unit direction with the (already normalized) axis.
    return ((vec / lengths) * axis).sum(dim=-1)
492
+
493
+
494
def _accumulate_rir(
    rir: Tensor, sample: Tensor, amplitude: Tensor, cfg: SimulationConfig
) -> None:
    """Accumulate fractional-delay contributions into the RIR tensor.

    In-place: scatters a windowed-sinc fractional-delay filter of length
    ``cfg.frac_delay_length`` around each arrival.

    Args:
        rir: Output buffer of shape (n_mic, nsample), modified in place.
        sample: Fractional sample positions (n_mic, n_img).
        amplitude: Per-path gains (n_mic, n_img).
        cfg: Simulation configuration (filter length, LUT, chunking).

    NOTE(review): this is the per-source 2-D variant; no visible caller in
    this module uses it (simulate_rir goes through _accumulate_rir_batch) —
    confirm external callers before removing.
    """
    # Split each arrival into integer sample index and fractional remainder.
    idx0 = torch.floor(sample).to(torch.int64)
    frac = sample - idx0.to(sample.dtype)

    n_mic, nsample = rir.shape
    fdl = cfg.frac_delay_length
    lut_gran = cfg.sinc_lut_granularity
    # LUT path disabled on MPS — presumably torch.take is unsupported or slow
    # there; confirm.
    use_lut = cfg.use_lut and rir.device.type != "mps"
    fdl2 = (fdl - 1) // 2

    dtype = amplitude.dtype
    # Cached constants: tap grid 0..fdl-1, centered integer offsets, Hann window.
    n = _get_fdl_grid(fdl, device=rir.device, dtype=dtype)
    offsets = _get_fdl_offsets(fdl, device=rir.device)
    window = _get_fdl_window(fdl, device=rir.device, dtype=dtype)

    if use_lut:
        sinc_lut = _get_sinc_lut(fdl, lut_gran, device=rir.device, dtype=dtype)

    # Flatten (mic, sample) into a single axis so one scatter_add_ covers all mics.
    mic_offsets = (torch.arange(n_mic, device=rir.device, dtype=torch.int64) * nsample).view(
        n_mic, 1, 1
    )
    rir_flat = rir.view(-1)

    chunk_size = cfg.accumulate_chunk_size
    n_img = idx0.shape[1]
    for start in range(0, n_img, chunk_size):
        end = min(start + chunk_size, n_img)
        idx = idx0[:, start:end]
        amp = amplitude[:, start:end]
        frac_m = frac[:, start:end]

        if use_lut:
            # LUT grid spacing is 1/lut_gran (see _get_sinc_lut); linearly
            # interpolate between the two neighboring LUT entries.
            x_off_frac = (1.0 - frac_m) * lut_gran
            lut_gran_off = torch.floor(x_off_frac).to(torch.int64)
            x_off = x_off_frac - lut_gran_off.to(dtype)
            lut_pos = lut_gran_off[..., None] + (n[None, None, :].to(torch.int64) * lut_gran)

            s0 = torch.take(sinc_lut, lut_pos)
            s1 = torch.take(sinc_lut, lut_pos + 1)
            interp = s0 + x_off[..., None] * (s1 - s0)
            filt = interp * window[None, None, :]
        else:
            # Direct evaluation: sinc centered at fdl2 + frac, Hann-windowed.
            t = n[None, None, :] - fdl2 - frac_m[..., None]
            filt = torch.sinc(t) * window[None, None, :]

        contrib = amp[..., None] * filt
        # Absolute tap positions for every arrival; drop out-of-range taps.
        target = idx[..., None] + offsets[None, None, :]
        valid = (target >= 0) & (target < nsample)
        if not valid.any():
            continue

        target = target + mic_offsets
        target_flat = target[valid].to(torch.int64)
        values_flat = contrib[valid]
        rir_flat.scatter_add_(0, target_flat, values_flat)
552
+
553
+
554
def _accumulate_rir_batch(
    rir: Tensor, sample: Tensor, amplitude: Tensor, cfg: SimulationConfig
) -> None:
    """Accumulate fractional-delay contributions for all sources/mics.

    Delegates to a (possibly torch.compile'd) accumulator bound to the
    configuration constants.
    """
    accumulate = _get_accumulate_fn(cfg, rir.device, amplitude.dtype)
    return accumulate(rir, sample, amplitude)
560
+
561
+
562
def _accumulate_rir_batch_impl(
    rir: Tensor,
    sample: Tensor,
    amplitude: Tensor,
    *,
    fdl: int,
    lut_gran: int,
    use_lut: bool,
    chunk_size: int,
) -> None:
    """Implementation for batch accumulation (optionally compiled).

    In-place: scatters a windowed-sinc fractional-delay filter of length
    ``fdl`` around each arrival into ``rir``.

    Args:
        rir: Output buffer (n_src, n_mic, nsample), modified in place.
        sample: Fractional sample positions, broadcastable to (n_src, n_mic, n_img).
        amplitude: Per-path gains, same broadcast shape as ``sample``.
        fdl: Fractional-delay filter length in taps (odd, centered).
        lut_gran: LUT entries per unit sample (LUT path only).
        use_lut: Use the sinc lookup table instead of evaluating torch.sinc.
        chunk_size: Number of images processed per scatter to bound memory.
    """
    # Split each arrival into integer sample index and fractional remainder.
    idx0 = torch.floor(sample).to(torch.int64)
    frac = sample - idx0.to(sample.dtype)

    # Collapse (src, mic) into one row axis; each row owns a contiguous
    # nsample-slice of the flattened output buffer.
    n_src, n_mic, nsample = rir.shape
    n_sm = n_src * n_mic
    idx0 = idx0.view(n_sm, -1)
    frac = frac.view(n_sm, -1)
    amplitude = amplitude.view(n_sm, -1)

    fdl2 = (fdl - 1) // 2

    # Cached constants: tap grid 0..fdl-1, centered integer offsets, Hann window.
    n = _get_fdl_grid(fdl, device=rir.device, dtype=sample.dtype)
    offsets = _get_fdl_offsets(fdl, device=rir.device)
    window = _get_fdl_window(fdl, device=rir.device, dtype=sample.dtype)

    if use_lut:
        sinc_lut = _get_sinc_lut(fdl, lut_gran, device=rir.device, dtype=sample.dtype)

    sm_offsets = (torch.arange(n_sm, device=rir.device, dtype=torch.int64) * nsample).view(
        n_sm, 1, 1
    )
    rir_flat = rir.view(-1)

    n_img = idx0.shape[1]
    for start in range(0, n_img, chunk_size):
        end = min(start + chunk_size, n_img)
        idx = idx0[:, start:end]
        amp = amplitude[:, start:end]
        frac_m = frac[:, start:end]

        if use_lut:
            # LUT grid spacing is 1/lut_gran (see _get_sinc_lut); linearly
            # interpolate between the two neighboring LUT entries.
            x_off_frac = (1.0 - frac_m) * lut_gran
            lut_gran_off = torch.floor(x_off_frac).to(torch.int64)
            x_off = x_off_frac - lut_gran_off.to(sample.dtype)
            lut_pos = lut_gran_off[..., None] + (n[None, None, :].to(torch.int64) * lut_gran)

            s0 = torch.take(sinc_lut, lut_pos)
            s1 = torch.take(sinc_lut, lut_pos + 1)
            interp = s0 + x_off[..., None] * (s1 - s0)
            filt = interp * window[None, None, :]
        else:
            # Direct evaluation: sinc centered at fdl2 + frac, Hann-windowed.
            t = n[None, None, :] - fdl2 - frac_m[..., None]
            filt = torch.sinc(t) * window[None, None, :]

        contrib = amp[..., None] * filt
        # Absolute tap positions for every arrival; drop out-of-range taps.
        target = idx[..., None] + offsets[None, None, :]
        valid = (target >= 0) & (target < nsample)
        if not valid.any():
            continue

        # Shift each row into its own slice of the flat buffer, then scatter.
        target = target + sm_offsets
        target_flat = target[valid].to(torch.int64)
        values_flat = contrib[valid]
        rir_flat.scatter_add_(0, target_flat, values_flat)
627
+
628
+
629
# Module-level caches so constant tensors and compiled kernels are built once
# per (parameters, device, dtype) combination. Devices are keyed by str() to
# keep keys hashable and stable.
_SINC_LUT_CACHE: dict[tuple[int, int, str, torch.dtype], Tensor] = {}  # sinc lookup tables
_FDL_GRID_CACHE: dict[tuple[int, str, torch.dtype], Tensor] = {}  # arange(fdl) tap grids
_FDL_OFFSETS_CACHE: dict[tuple[int, str], Tensor] = {}  # centered integer tap offsets
_FDL_WINDOW_CACHE: dict[tuple[int, str, torch.dtype], Tensor] = {}  # Hann windows
_ACCUM_BATCH_COMPILED: dict[tuple[str, torch.dtype, int, int, bool, int], callable] = {}  # torch.compile'd accumulators
634
+
635
+
636
def _get_accumulate_fn(
    cfg: SimulationConfig, device: torch.device, dtype: torch.dtype
) -> callable:
    """Return an accumulation function with config-bound constants.

    The returned callable wraps ``_accumulate_rir_batch_impl`` with the
    configuration's filter/LUT/chunking settings baked in; on CUDA/MPS with
    ``cfg.use_compile`` enabled, a ``torch.compile``'d version is cached and
    reused per (device, dtype, settings) key.
    """
    use_lut = cfg.use_lut and device.type != "mps"
    fdl = cfg.frac_delay_length
    lut_gran = cfg.sinc_lut_granularity
    chunk_size = cfg.accumulate_chunk_size

    def _eager(rir: Tensor, sample: Tensor, amplitude: Tensor) -> None:
        _accumulate_rir_batch_impl(
            rir,
            sample,
            amplitude,
            fdl=fdl,
            lut_gran=lut_gran,
            use_lut=use_lut,
            chunk_size=chunk_size,
        )

    # Compilation is only attempted on accelerator backends.
    if not cfg.use_compile or device.type not in ("cuda", "mps"):
        return _eager
    key = (str(device), dtype, fdl, lut_gran, use_lut, chunk_size)
    if key not in _ACCUM_BATCH_COMPILED:
        _ACCUM_BATCH_COMPILED[key] = torch.compile(_eager, dynamic=True)
    return _ACCUM_BATCH_COMPILED[key]
664
+
665
+
666
def _get_fdl_grid(fdl: int, *, device: torch.device, dtype: torch.dtype) -> Tensor:
    """Cached tap-index grid 0..fdl-1 for the fractional-delay filter."""
    key = (fdl, str(device), dtype)
    if key not in _FDL_GRID_CACHE:
        _FDL_GRID_CACHE[key] = torch.arange(fdl, device=device, dtype=dtype)
    return _FDL_GRID_CACHE[key]
673
+
674
+
675
def _get_fdl_offsets(fdl: int, *, device: torch.device) -> Tensor:
    """Cached integer tap offsets -fdl2..+fdl2 centered on zero."""
    key = (fdl, str(device))
    if key not in _FDL_OFFSETS_CACHE:
        half = (fdl - 1) // 2
        _FDL_OFFSETS_CACHE[key] = torch.arange(fdl, device=device, dtype=torch.int64) - half
    return _FDL_OFFSETS_CACHE[key]
683
+
684
+
685
def _get_fdl_window(fdl: int, *, device: torch.device, dtype: torch.dtype) -> Tensor:
    """Cached symmetric Hann window of length fdl."""
    key = (fdl, str(device), dtype)
    if key not in _FDL_WINDOW_CACHE:
        _FDL_WINDOW_CACHE[key] = torch.hann_window(
            fdl, periodic=False, device=device, dtype=dtype
        )
    return _FDL_WINDOW_CACHE[key]
692
+
693
+
694
def _get_sinc_lut(fdl: int, lut_gran: int, *, device: torch.device, dtype: torch.dtype) -> Tensor:
    """Create (and cache) a sinc lookup table for fractional delays.

    The table samples ``sinc`` on [-fdl2 - 1, fdl2 + 1] with a grid spacing
    of 1 / lut_gran, i.e. (fdl + 1) * lut_gran + 1 entries.
    """
    key = (fdl, lut_gran, str(device), dtype)
    if key not in _SINC_LUT_CACHE:
        half = (fdl - 1) // 2
        entries = (fdl + 1) * lut_gran + 1
        grid = torch.linspace(-half - 1, half + 1, entries, device=device, dtype=dtype)
        _SINC_LUT_CACHE[key] = torch.sinc(grid)
    return _SINC_LUT_CACHE[key]
706
+
707
+
708
def _apply_diffuse_tail(
    rir: Tensor,
    room: Room,
    beta: Tensor,
    tdiff: float,
    tmax: float,
    *,
    seed: Optional[int] = None,
) -> Tensor:
    """Apply a diffuse reverberation tail after tdiff.

    Everything from ``tdiff`` onwards is REPLACED (not mixed) with
    exponentially decaying noise scaled to the last specular sample.
    Operates in place on ``rir`` and also returns it.

    Args:
        rir: RIR tensor (..., nsample), modified in place.
        room: Room providing ``fs`` and ``size``.
        beta: Per-wall reflection coefficients, used to estimate T60.
        tdiff: Tail start time in seconds.
        tmax: Accepted but unused here; the caller in this module only
            invokes this when tdiff < tmax.
        seed: RNG seed for the noise; None falls back to seed 0, so the
            tail is deterministic either way.

    Returns:
        RIR tensor with diffuse tail applied.
    """
    nsample = rir.shape[-1]
    tdiff_idx = min(nsample - 1, int(math.floor(tdiff * room.fs)))
    if tdiff_idx <= 0:
        return rir
    tail_len = nsample - tdiff_idx
    # Time axis of the tail, in seconds from tdiff.
    t = torch.arange(tail_len, device=rir.device, dtype=rir.dtype) / room.fs

    t60 = estimate_t60_from_beta(room.size, beta)
    if math.isinf(t60) or t60 <= 0:
        # Degenerate T60 (e.g. lossless walls): fall back to a fixed decay rate.
        decay = torch.exp(-t * 3.0)
    else:
        # tau = T60 / ln(1000): exp(-t / tau) reaches -60 dB at t = T60.
        tau = t60 / 6.9078
        decay = torch.exp(-t / tau)

    # Dedicated seeded generator keeps the tail reproducible across calls.
    gen = torch.Generator(device=rir.device)
    gen.manual_seed(0 if seed is None else seed)
    noise = torch.randn(rir[..., tdiff_idx:].shape, device=rir.device, dtype=rir.dtype, generator=gen)
    # Norm over a length-1 slice == |rir[..., tdiff_idx - 1]|, kept as a slice
    # so the trailing dim broadcasts; epsilon avoids an all-zero tail scale.
    scale = torch.linalg.norm(rir[..., tdiff_idx - 1 : tdiff_idx], dim=-1, keepdim=True) + 1e-8
    rir[..., tdiff_idx:] = noise * decay * scale
    return rir