torchrir 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchrir/__init__.py +85 -0
- torchrir/config.py +59 -0
- torchrir/core.py +741 -0
- torchrir/datasets/__init__.py +27 -0
- torchrir/datasets/base.py +27 -0
- torchrir/datasets/cmu_arctic.py +204 -0
- torchrir/datasets/template.py +65 -0
- torchrir/datasets/utils.py +74 -0
- torchrir/directivity.py +33 -0
- torchrir/dynamic.py +60 -0
- torchrir/logging_utils.py +55 -0
- torchrir/plotting.py +210 -0
- torchrir/plotting_utils.py +173 -0
- torchrir/results.py +22 -0
- torchrir/room.py +150 -0
- torchrir/scene.py +67 -0
- torchrir/scene_utils.py +51 -0
- torchrir/signal.py +233 -0
- torchrir/simulators.py +86 -0
- torchrir/utils.py +281 -0
- torchrir-0.1.0.dist-info/METADATA +213 -0
- torchrir-0.1.0.dist-info/RECORD +26 -0
- torchrir-0.1.0.dist-info/WHEEL +5 -0
- torchrir-0.1.0.dist-info/licenses/LICENSE +190 -0
- torchrir-0.1.0.dist-info/licenses/NOTICE +4 -0
- torchrir-0.1.0.dist-info/top_level.txt +1 -0
torchrir/core.py
ADDED
|
@@ -0,0 +1,741 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
"""Core RIR simulation functions (static and dynamic)."""
|
|
4
|
+
|
|
5
|
+
import math
|
|
6
|
+
from typing import Optional, Tuple
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
from torch import Tensor
|
|
10
|
+
|
|
11
|
+
from .config import SimulationConfig, default_config
|
|
12
|
+
from .directivity import directivity_gain, split_directivity
|
|
13
|
+
from .room import MicrophoneArray, Room, Source
|
|
14
|
+
from .utils import (
|
|
15
|
+
as_tensor,
|
|
16
|
+
ensure_dim,
|
|
17
|
+
estimate_beta_from_t60,
|
|
18
|
+
estimate_t60_from_beta,
|
|
19
|
+
infer_device_dtype,
|
|
20
|
+
normalize_orientation,
|
|
21
|
+
orientation_to_unit,
|
|
22
|
+
resolve_device,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def simulate_rir(
    *,
    room: Room,
    sources: Source | Tensor,
    mics: MicrophoneArray | Tensor,
    max_order: int | None,
    nb_img: Optional[Tensor | Tuple[int, ...]] = None,
    nsample: Optional[int] = None,
    tmax: Optional[float] = None,
    tdiff: Optional[float] = None,
    directivity: str | tuple[str, str] | None = "omni",
    orientation: Optional[Tensor | tuple[Tensor, Tensor]] = None,
    config: Optional[SimulationConfig] = None,
    device: Optional[torch.device | str] = None,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    """Simulate a static RIR using the image source method.

    Args:
        room: Room configuration (geometry, fs, reflection coefficients).
        sources: Source positions or a Source object.
        mics: Microphone positions or a MicrophoneArray object.
        max_order: Maximum reflection order (uses config if None).
        nb_img: Optional per-dimension image counts (overrides max_order).
        nsample: Output length in samples.
        tmax: Output length in seconds (used if nsample is None).
        tdiff: Optional time in seconds from which the response is replaced
            by a modeled diffuse tail. The tail is applied whenever ``tdiff``
            is shorter than the produced length (``nsample / room.fs``),
            whether that length came from ``nsample`` or from ``tmax``.
        directivity: Directivity pattern(s) for source and mic (uses config if None).
        orientation: Orientation vectors or angles.
        config: Optional simulation configuration overrides.
        device: Output device.
        dtype: Output dtype.

    Returns:
        Tensor of shape (n_src, n_mic, nsample).

    Raises:
        TypeError: If ``room`` is not a ``Room``.
        ValueError: If the length, order, geometry, reflection coefficients
            or orientations are missing or inconsistent.
    """
    cfg = config or default_config()
    cfg.validate()

    if device is None and cfg.device is not None:
        device = cfg.device

    if max_order is None:
        if cfg.max_order is None:
            raise ValueError("max_order must be provided if not set in config")
        max_order = cfg.max_order

    if tmax is None and nsample is None and cfg.tmax is not None:
        tmax = cfg.tmax

    if directivity is None:
        directivity = cfg.directivity or "omni"

    if not isinstance(room, Room):
        raise TypeError("room must be a Room instance")
    if nsample is None and tmax is None:
        raise ValueError("nsample or tmax must be provided")
    if nsample is None:
        nsample = int(math.ceil(tmax * room.fs))
    if nsample <= 0:
        raise ValueError("nsample must be positive")
    if max_order < 0:
        raise ValueError("max_order must be non-negative")

    if isinstance(device, str):
        device = resolve_device(device)

    src_pos, src_ori = _prepare_entities(
        sources, orientation, which="source", device=device, dtype=dtype
    )
    mic_pos, mic_ori = _prepare_entities(
        mics, orientation, which="mic", device=device, dtype=dtype
    )

    device, dtype = infer_device_dtype(
        src_pos, mic_pos, room.size, device=device, dtype=dtype
    )
    src_pos = as_tensor(src_pos, device=device, dtype=dtype)
    mic_pos = as_tensor(mic_pos, device=device, dtype=dtype)

    if src_ori is not None:
        src_ori = as_tensor(src_ori, device=device, dtype=dtype)
    if mic_ori is not None:
        mic_ori = as_tensor(mic_ori, device=device, dtype=dtype)

    room_size = as_tensor(room.size, device=device, dtype=dtype)
    room_size = ensure_dim(room_size)
    dim = room_size.numel()

    # Accept a single position as a one-entity batch.
    if src_pos.ndim == 1:
        src_pos = src_pos.unsqueeze(0)
    if mic_pos.ndim == 1:
        mic_pos = mic_pos.unsqueeze(0)
    if src_pos.ndim != 2 or src_pos.shape[1] != dim:
        raise ValueError("sources must be of shape (n_src, dim)")
    if mic_pos.ndim != 2 or mic_pos.shape[1] != dim:
        raise ValueError("mics must be of shape (n_mic, dim)")

    beta = _resolve_beta(room, room_size, device=device, dtype=dtype)
    beta = _validate_beta(beta, dim)

    n_vec = _image_source_indices(max_order, dim, device=device, nb_img=nb_img)
    refl = _reflection_coefficients(n_vec, beta)

    n_src = src_pos.shape[0]
    n_mic = mic_pos.shape[0]

    src_pattern, mic_pattern = split_directivity(directivity)
    mic_dir = None
    if mic_pattern != "omni":
        if mic_ori is None:
            raise ValueError("mic orientation required for non-omni directivity")
        mic_dir = orientation_to_unit(mic_ori, dim)
        # A per-mic table must line up with the mic axis (or be a single row);
        # otherwise broadcasting fails deep inside the batch kernel.
        if mic_dir.ndim == 2 and mic_dir.shape[0] not in (1, n_mic):
            raise ValueError("mic orientation must match number of mics")

    rir = torch.zeros((n_src, n_mic, nsample), device=device, dtype=dtype)
    fdl = cfg.frac_delay_length
    fdl2 = (fdl - 1) // 2
    img_chunk = cfg.image_chunk_size
    if img_chunk <= 0:
        # Non-positive chunk size means "process all images at once".
        img_chunk = n_vec.shape[0]

    src_dirs = None
    if src_pattern != "omni":
        if src_ori is None:
            raise ValueError("source orientation required for non-omni directivity")
        src_dirs = orientation_to_unit(src_ori, dim)
        if src_dirs.ndim == 1:
            # One orientation shared by every source.
            src_dirs = src_dirs.unsqueeze(0).repeat(n_src, 1)
        if src_dirs.ndim != 2 or src_dirs.shape[0] != n_src:
            raise ValueError("source orientation must match number of sources")

    for start in range(0, n_vec.shape[0], img_chunk):
        end = min(start + img_chunk, n_vec.shape[0])
        n_vec_chunk = n_vec[start:end]
        refl_chunk = refl[start:end]
        sample_chunk, attenuation_chunk = _compute_image_contributions_batch(
            src_pos,
            mic_pos,
            room_size,
            n_vec_chunk,
            refl_chunk,
            room,
            fdl2,
            src_pattern=src_pattern,
            mic_pattern=mic_pattern,
            src_dirs=src_dirs,
            mic_dir=mic_dir,
        )
        _accumulate_rir_batch(rir, sample_chunk, attenuation_chunk, cfg)

    # Gate the diffuse tail on the length actually produced (nsample wins over
    # tmax), so tdiff is honoured when the caller specified nsample.
    duration = nsample / room.fs
    if tdiff is not None and tdiff < duration:
        rir = _apply_diffuse_tail(rir, room, beta, tdiff, duration, seed=cfg.seed)
    return rir
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def simulate_dynamic_rir(
    *,
    room: Room,
    src_traj: Tensor,
    mic_traj: Tensor,
    max_order: int | None,
    nsample: Optional[int] = None,
    tmax: Optional[float] = None,
    directivity: str | tuple[str, str] | None = "omni",
    orientation: Optional[Tensor | tuple[Tensor, Tensor]] = None,
    config: Optional[SimulationConfig] = None,
    device: Optional[torch.device | str] = None,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    """Simulate time-varying RIRs for source/mic trajectories.

    Each time step is simulated independently with ``simulate_rir`` and the
    results are stacked along a new leading time axis.

    Args:
        room: Room configuration.
        src_traj: Source trajectory (T, n_src, dim); a (T, dim) input is
            treated as a single source.
        mic_traj: Microphone trajectory (T, n_mic, dim); a (T, dim) input is
            treated as a single microphone.
        max_order: Maximum reflection order (uses config if None).
        nsample: Output length in samples.
        tmax: Output length in seconds (used if nsample is None).
        directivity: Directivity pattern(s) for source and mic (uses config if None).
        orientation: Orientation vectors or angles.
        config: Optional simulation configuration overrides.
        device: Output device.
        dtype: Output dtype.

    Returns:
        Tensor of shape (T, n_src, n_mic, nsample).

    Raises:
        ValueError: If ``max_order`` cannot be resolved, if the trajectories
            have the wrong rank, differing or zero time lengths, or if any
            per-step simulation rejects its inputs.
    """
    cfg = config or default_config()
    cfg.validate()

    if device is None and cfg.device is not None:
        device = cfg.device

    if max_order is None:
        if cfg.max_order is None:
            raise ValueError("max_order must be provided if not set in config")
        max_order = cfg.max_order

    if tmax is None and nsample is None and cfg.tmax is not None:
        tmax = cfg.tmax

    if directivity is None:
        directivity = cfg.directivity or "omni"

    if isinstance(device, str):
        device = resolve_device(device)

    src_traj = as_tensor(src_traj, device=device, dtype=dtype)
    mic_traj = as_tensor(mic_traj, device=device, dtype=dtype)

    if src_traj.ndim == 2:
        src_traj = src_traj.unsqueeze(1)
    if mic_traj.ndim == 2:
        mic_traj = mic_traj.unsqueeze(1)
    if src_traj.ndim != 3:
        raise ValueError("src_traj must be of shape (T, n_src, dim)")
    if mic_traj.ndim != 3:
        raise ValueError("mic_traj must be of shape (T, n_mic, dim)")
    if src_traj.shape[0] != mic_traj.shape[0]:
        raise ValueError("src_traj and mic_traj must have the same time length")

    t_steps = src_traj.shape[0]
    # torch.stack rejects an empty list with an opaque RuntimeError; report
    # the actual input problem instead.
    if t_steps == 0:
        raise ValueError("src_traj and mic_traj must contain at least one time step")

    rirs = []
    for t_idx in range(t_steps):
        rir = simulate_rir(
            room=room,
            sources=src_traj[t_idx],
            mics=mic_traj[t_idx],
            max_order=max_order,
            nsample=nsample,
            tmax=tmax,
            directivity=directivity,
            orientation=orientation,
            config=config,
            device=device,
            dtype=dtype,
        )
        rirs.append(rir)
    return torch.stack(rirs, dim=0)
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def _prepare_entities(
    entities: Source | MicrophoneArray | Tensor,
    orientation: Optional[Tensor | tuple[Tensor, Tensor]],
    *,
    which: str,
    device: Optional[torch.device | str],
    dtype: Optional[torch.dtype],
) -> Tuple[Tensor, Optional[Tensor]]:
    """Split an entity argument into position and orientation tensors.

    ``entities`` is either a ``Source``/``MicrophoneArray`` (its
    ``positions`` and ``orientation`` attributes are read) or raw positions,
    in which case no orientation is attached. A non-None ``orientation``
    argument takes precedence. A list/tuple is read as a
    ``(source_orientation, mic_orientation)`` pair; ``which == "source"``
    selects the first element and any other value selects the second. Any
    other value is used as-is.

    Returns:
        Tuple of (positions, orientation); orientation may be None.

    Raises:
        ValueError: If a list/tuple ``orientation`` does not have length 2.
    """
    is_entity = isinstance(entities, (Source, MicrophoneArray))
    positions = entities.positions if is_entity else entities
    ori = entities.orientation if is_entity else None

    if orientation is not None:
        if not isinstance(orientation, (list, tuple)):
            ori = orientation
        elif len(orientation) == 2:
            ori = orientation[0] if which == "source" else orientation[1]
        else:
            raise ValueError("orientation tuple must have length 2")

    positions = as_tensor(positions, device=device, dtype=dtype)
    if ori is None:
        return positions, None
    return positions, as_tensor(ori, device=device, dtype=dtype)
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
def _resolve_beta(
|
|
299
|
+
room: Room, room_size: Tensor, *, device: torch.device, dtype: torch.dtype
|
|
300
|
+
) -> Tensor:
|
|
301
|
+
"""Resolve reflection coefficients from beta/t60/defaults.
|
|
302
|
+
|
|
303
|
+
Returns:
|
|
304
|
+
Tensor of reflection coefficients per wall.
|
|
305
|
+
"""
|
|
306
|
+
if room.beta is not None:
|
|
307
|
+
return as_tensor(room.beta, device=device, dtype=dtype)
|
|
308
|
+
if room.t60 is not None:
|
|
309
|
+
return estimate_beta_from_t60(room_size, room.t60, device=device, dtype=dtype)
|
|
310
|
+
dim = room_size.numel()
|
|
311
|
+
default_faces = 4 if dim == 2 else 6
|
|
312
|
+
return torch.ones((default_faces,), device=device, dtype=dtype)
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
def _validate_beta(beta: Tensor, dim: int) -> Tensor:
|
|
316
|
+
"""Validate beta size against room dimension.
|
|
317
|
+
|
|
318
|
+
Returns:
|
|
319
|
+
The validated beta tensor.
|
|
320
|
+
"""
|
|
321
|
+
expected = 4 if dim == 2 else 6
|
|
322
|
+
if beta.numel() != expected:
|
|
323
|
+
raise ValueError(f"beta must have {expected} elements for {dim}D")
|
|
324
|
+
return beta
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
def _image_source_indices(
|
|
328
|
+
max_order: int,
|
|
329
|
+
dim: int,
|
|
330
|
+
*,
|
|
331
|
+
device: torch.device,
|
|
332
|
+
nb_img: Optional[Tensor | Tuple[int, ...]] = None,
|
|
333
|
+
) -> Tensor:
|
|
334
|
+
"""Generate image source index vectors up to the given order.
|
|
335
|
+
|
|
336
|
+
Returns:
|
|
337
|
+
Tensor of shape (n_images, dim).
|
|
338
|
+
"""
|
|
339
|
+
if nb_img is not None:
|
|
340
|
+
nb = as_tensor(nb_img, device=device, dtype=torch.int64)
|
|
341
|
+
if nb.numel() != dim:
|
|
342
|
+
raise ValueError("nb_img must match room dimension")
|
|
343
|
+
ranges = [torch.arange(-n, n + 1, device=device, dtype=torch.int64) for n in nb]
|
|
344
|
+
grids = torch.meshgrid(*ranges, indexing="ij")
|
|
345
|
+
return torch.stack([g.reshape(-1) for g in grids], dim=-1)
|
|
346
|
+
rng = torch.arange(-max_order, max_order + 1, device=device, dtype=torch.int64)
|
|
347
|
+
grids = torch.meshgrid(*([rng] * dim), indexing="ij")
|
|
348
|
+
n_vec = torch.stack([g.reshape(-1) for g in grids], dim=-1)
|
|
349
|
+
order = torch.sum(torch.abs(n_vec), dim=-1)
|
|
350
|
+
return n_vec[order <= max_order]
|
|
351
|
+
|
|
352
|
+
|
|
353
|
+
def _image_positions(src: Tensor, room_size: Tensor, n_vec: Tensor) -> Tensor:
|
|
354
|
+
"""Compute image source positions for a given source.
|
|
355
|
+
|
|
356
|
+
Returns:
|
|
357
|
+
Tensor of image positions (n_images, dim).
|
|
358
|
+
"""
|
|
359
|
+
n_vec_f = n_vec.to(dtype=src.dtype)
|
|
360
|
+
sign = torch.where((n_vec % 2) == 0, 1.0, -1.0).to(dtype=src.dtype)
|
|
361
|
+
n = torch.floor_divide(n_vec + 1, 2).to(dtype=src.dtype)
|
|
362
|
+
return 2.0 * room_size * n + sign * src
|
|
363
|
+
|
|
364
|
+
|
|
365
|
+
def _image_positions_batch(src_pos: Tensor, room_size: Tensor, n_vec: Tensor) -> Tensor:
|
|
366
|
+
"""Compute image source positions for multiple sources."""
|
|
367
|
+
sign = torch.where((n_vec % 2) == 0, 1.0, -1.0).to(dtype=src_pos.dtype)
|
|
368
|
+
n = torch.floor_divide(n_vec + 1, 2).to(dtype=src_pos.dtype)
|
|
369
|
+
base = 2.0 * room_size * n
|
|
370
|
+
return base[None, :, :] + sign[None, :, :] * src_pos[:, None, :]
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
def _reflection_coefficients(n_vec: Tensor, beta: Tensor) -> Tensor:
|
|
374
|
+
"""Compute reflection coefficients for each image source.
|
|
375
|
+
|
|
376
|
+
Returns:
|
|
377
|
+
Tensor of shape (n_images,) with per-image gains.
|
|
378
|
+
"""
|
|
379
|
+
dim = n_vec.shape[1]
|
|
380
|
+
beta = beta.view(dim, 2)
|
|
381
|
+
beta_lo = beta[:, 0]
|
|
382
|
+
beta_hi = beta[:, 1]
|
|
383
|
+
|
|
384
|
+
n = n_vec
|
|
385
|
+
k = torch.abs(n)
|
|
386
|
+
n_hi = torch.where(n >= 0, (n + 1) // 2, k // 2)
|
|
387
|
+
n_lo = torch.where(n >= 0, n // 2, (k + 1) // 2)
|
|
388
|
+
|
|
389
|
+
n_hi = n_hi.to(dtype=beta.dtype)
|
|
390
|
+
n_lo = n_lo.to(dtype=beta.dtype)
|
|
391
|
+
|
|
392
|
+
coeff = (beta_hi**n_hi) * (beta_lo**n_lo)
|
|
393
|
+
return torch.prod(coeff, dim=1)
|
|
394
|
+
|
|
395
|
+
|
|
396
|
+
def _compute_image_contributions(
    src: Tensor,
    mic_pos: Tensor,
    room_size: Tensor,
    n_vec: Tensor,
    refl: Tensor,
    room: Room,
    fdl2: int,
    *,
    src_pattern: str,
    mic_pattern: str,
    src_dir: Optional[Tensor],
    mic_dir: Optional[Tensor],
) -> Tuple[Tensor, Tensor]:
    """Compute sample positions and attenuation for one source and all mics.

    Args:
        src: Source position.
        mic_pos: Microphone positions, indexed (n_mic, dim).
        room_size: Room extent per axis.
        n_vec: Integer image indices, indexed (n_img, dim).
        refl: Per-image reflection gains, indexed (n_img,).
        room: Room providing ``c`` and ``fs``.
        fdl2: Half length of the fractional-delay filter; added as a delay in
            samples so the filter's centre tap lands on the arrival time.
        src_pattern: Source directivity pattern name.
        mic_pattern: Microphone directivity pattern name.
        src_dir: Source orientation (required when ``src_pattern`` is not omni).
        mic_dir: Mic orientation (required when ``mic_pattern`` is not omni).

    Returns:
        ``(sample, attenuation)``, each indexed (n_mic, n_img).

    Raises:
        ValueError: If a non-omni pattern is requested without an orientation.
    """
    img = _image_positions(src, room_size, n_vec)
    # vec[m, i] points from image i to mic m.
    vec = mic_pos[:, None, :] - img[None, :, :]
    dist = torch.linalg.norm(vec, dim=-1)
    dist = torch.clamp(dist, min=1e-6)
    time = dist / room.c
    time = time + (fdl2 / room.fs)
    sample = time * room.fs

    gain = refl[None, :]
    if src_pattern != "omni":
        if src_dir is None:
            raise ValueError("source orientation required for non-omni directivity")
        # Every odd index along an axis mirrors the path about a wall, so the
        # direction in which the *real* source emitted is the image-to-mic
        # vector reflected back by (-1)**n per axis.
        parity = torch.where((n_vec % 2) == 0, 1.0, -1.0).to(dtype=vec.dtype)
        departure = vec * parity[None, :, :]
        cos_theta = _cos_between(departure, src_dir)
        gain = gain * directivity_gain(src_pattern, cos_theta)
    if mic_pattern != "omni":
        if mic_dir is None:
            raise ValueError("mic orientation required for non-omni directivity")
        # The mic is real, so the arrival direction (mic -> image) needs no mirroring.
        cos_theta = _cos_between(-vec, mic_dir)
        gain = gain * directivity_gain(mic_pattern, cos_theta)

    attenuation = gain / dist
    return sample, attenuation
|
|
433
|
+
|
|
434
|
+
|
|
435
|
+
def _compute_image_contributions_batch(
    src_pos: Tensor,
    mic_pos: Tensor,
    room_size: Tensor,
    n_vec: Tensor,
    refl: Tensor,
    room: Room,
    fdl2: int,
    *,
    src_pattern: str,
    mic_pattern: str,
    src_dirs: Optional[Tensor],
    mic_dir: Optional[Tensor],
) -> Tuple[Tensor, Tensor]:
    """Compute samples/attenuation for all sources/mics/images in batch.

    Args:
        src_pos: Source positions, indexed (n_src, dim).
        mic_pos: Microphone positions, indexed (n_mic, dim).
        room_size: Room extent per axis.
        n_vec: Integer image indices, indexed (n_img, dim).
        refl: Per-image reflection gains, indexed (n_img,).
        room: Room providing ``c`` and ``fs``.
        fdl2: Half length of the fractional-delay filter; added as a delay in
            samples so the filter's centre tap lands on the arrival time.
        src_pattern: Source directivity pattern name.
        mic_pattern: Microphone directivity pattern name.
        src_dirs: Per-source unit orientations, indexed (n_src, dim)
            (required when ``src_pattern`` is not omni).
        mic_dir: Mic orientation(s), either (n_mic, dim) or a single vector
            (required when ``mic_pattern`` is not omni).

    Returns:
        ``(sample, attenuation)``, each indexed (n_src, n_mic, n_img).

    Raises:
        ValueError: If a non-omni pattern is requested without an orientation.
    """
    img = _image_positions_batch(src_pos, room_size, n_vec)
    # vec[s, m, i] points from image i of source s to mic m.
    vec = mic_pos[None, :, None, :] - img[:, None, :, :]
    dist = torch.linalg.norm(vec, dim=-1)
    dist = torch.clamp(dist, min=1e-6)
    time = dist / room.c
    time = time + (fdl2 / room.fs)
    sample = time * room.fs

    gain = refl.view(1, 1, -1)
    if src_pattern != "omni":
        if src_dirs is None:
            raise ValueError("source orientation required for non-omni directivity")
        # Every odd index along an axis mirrors the path about a wall, so the
        # direction in which the *real* source emitted is the image-to-mic
        # vector reflected back by (-1)**n per axis.
        parity = torch.where((n_vec % 2) == 0, 1.0, -1.0).to(dtype=vec.dtype)
        departure = vec * parity[None, None, :, :]
        cos_theta = _cos_between(departure, src_dirs[:, None, None, :])
        gain = gain * directivity_gain(src_pattern, cos_theta)
    if mic_pattern != "omni":
        if mic_dir is None:
            raise ValueError("mic orientation required for non-omni directivity")
        mic_dir = mic_dir[None, :, None, :] if mic_dir.ndim == 2 else mic_dir.view(1, 1, 1, -1)
        # The mic is real, so the arrival direction (mic -> image) needs no mirroring.
        cos_theta = _cos_between(-vec, mic_dir)
        gain = gain * directivity_gain(mic_pattern, cos_theta)

    attenuation = gain / dist
    return sample, attenuation
|
|
474
|
+
|
|
475
|
+
|
|
476
|
+
def _select_orientation(orientation: Tensor, idx: int, count: int, dim: int) -> Tensor:
    """Return the unit orientation for entity ``idx`` out of ``count``.

    Rank-0 and rank-1 inputs are shared by every entity. A rank-2 input with
    exactly ``count`` rows is indexed by ``idx``. The selected value is
    converted with ``orientation_to_unit``.

    Raises:
        ValueError: For any other shape.
    """
    rank = orientation.ndim
    if rank in (0, 1):
        chosen = orientation
    elif rank == 2 and orientation.shape[0] == count:
        chosen = orientation[idx]
    else:
        raise ValueError("orientation must be shape (dim,), (count, dim), or angles")
    return orientation_to_unit(chosen, dim)
|
|
485
|
+
|
|
486
|
+
|
|
487
|
+
def _cos_between(vec: Tensor, orientation: Tensor) -> Tensor:
    """Compute the cosine between direction vectors and an orientation.

    Args:
        vec: Direction vectors; the last axis holds the components.
        orientation: Orientation broadcastable against ``vec``; normalized
            with ``normalize_orientation`` before use.

    Returns:
        Cosines with the last axis of ``vec`` reduced. Zero-length vectors
        yield 0 instead of NaN.
    """
    orientation = normalize_orientation(orientation)
    norm = torch.linalg.norm(vec, dim=-1, keepdim=True)
    # An image source coinciding with a mic gives a zero vector; without the
    # clamp, 0/0 = NaN would propagate into the whole RIR via scatter_add.
    norm = torch.clamp(norm, min=torch.finfo(norm.dtype).tiny)
    unit = vec / norm
    return torch.sum(unit * orientation, dim=-1)
|
|
492
|
+
|
|
493
|
+
|
|
494
|
+
def _accumulate_rir(
    rir: Tensor, sample: Tensor, amplitude: Tensor, cfg: SimulationConfig
) -> None:
    """Scatter windowed-sinc fractional delays into a single-source RIR, in place.

    Args:
        rir: Output buffer unpacked as (n_mic, nsample); modified in place
            through a flat view.
        sample: Fractional arrival positions in samples, indexed (n_mic, n_img).
        amplitude: Per-arrival gains with the same layout as ``sample``.
        cfg: Supplies ``frac_delay_length``, ``sinc_lut_granularity``,
            ``use_lut`` and ``accumulate_chunk_size``.
    """
    dev = rir.device
    n_mic, nsample = rir.shape
    taps_len = cfg.frac_delay_length
    gran = cfg.sinc_lut_granularity
    # The lookup-table path is never used on MPS.
    lut_enabled = cfg.use_lut and dev.type != "mps"
    half = (taps_len - 1) // 2
    fdtype = amplitude.dtype

    base_idx = torch.floor(sample).to(torch.int64)
    frac = sample - base_idx.to(sample.dtype)

    tap_grid = _get_fdl_grid(taps_len, device=dev, dtype=fdtype)
    tap_offsets = _get_fdl_offsets(taps_len, device=dev).view(1, 1, -1)
    window = _get_fdl_window(taps_len, device=dev, dtype=fdtype).view(1, 1, -1)
    if lut_enabled:
        sinc_lut = _get_sinc_lut(taps_len, gran, device=dev, dtype=fdtype)
        # LUT index of each tap's integer part, hoisted out of the chunk loop.
        lut_steps = tap_grid.to(torch.int64).view(1, 1, -1) * gran
    else:
        centered = tap_grid.view(1, 1, -1) - half

    # Start offset of each mic's row inside the flattened buffer.
    row_base = (torch.arange(n_mic, device=dev, dtype=torch.int64) * nsample).view(n_mic, 1, 1)
    flat_rir = rir.view(-1)

    step = cfg.accumulate_chunk_size
    n_img = base_idx.shape[1]
    for begin in range(0, n_img, step):
        stop = begin + step
        idx = base_idx[:, begin:stop]
        amp = amplitude[:, begin:stop]
        frac_m = frac[:, begin:stop]

        if lut_enabled:
            # Linear interpolation between neighbouring LUT entries.
            scaled = (1.0 - frac_m) * gran
            coarse = torch.floor(scaled).to(torch.int64)
            resid = (scaled - coarse.to(fdtype)).unsqueeze(-1)
            pos = coarse.unsqueeze(-1) + lut_steps
            lo_val = torch.take(sinc_lut, pos)
            hi_val = torch.take(sinc_lut, pos + 1)
            filt = (lo_val + resid * (hi_val - lo_val)) * window
        else:
            filt = torch.sinc(centered - frac_m.unsqueeze(-1)) * window

        contrib = amp.unsqueeze(-1) * filt
        target = idx.unsqueeze(-1) + tap_offsets
        in_range = (target >= 0) & (target < nsample)
        if in_range.any():
            flat_rir.scatter_add_(0, (target + row_base)[in_range], contrib[in_range])
|
|
552
|
+
|
|
553
|
+
|
|
554
|
+
def _accumulate_rir_batch(
    rir: Tensor, sample: Tensor, amplitude: Tensor, cfg: SimulationConfig
) -> None:
    """Add fractional-delay contributions for all sources/mics into ``rir`` in place.

    Delegates to the (possibly ``torch.compile``d) accumulator returned by
    ``_get_accumulate_fn`` for this config, device and amplitude dtype.
    """
    return _get_accumulate_fn(cfg, rir.device, amplitude.dtype)(rir, sample, amplitude)
|
|
560
|
+
|
|
561
|
+
|
|
562
|
+
def _accumulate_rir_batch_impl(
    rir: Tensor,
    sample: Tensor,
    amplitude: Tensor,
    *,
    fdl: int,
    lut_gran: int,
    use_lut: bool,
    chunk_size: int,
) -> None:
    """Scatter windowed-sinc fractional delays for every source/mic pair, in place.

    The (source, mic) axes are flattened into one row axis, so each image
    chunk is written with a single ``scatter_add_`` into a flat view of ``rir``.

    Args:
        rir: Output buffer unpacked as (n_src, n_mic, nsample); modified in place.
        sample: Fractional arrival positions in samples; viewed as
            (n_src * n_mic, n_img).
        amplitude: Per-arrival gains with the same layout as ``sample``.
        fdl: Fractional-delay filter length (number of taps).
        lut_gran: Sinc lookup-table samples per unit delay.
        use_lut: Interpolate from the cached sinc table instead of calling
            ``torch.sinc`` directly.
        chunk_size: Number of images processed per scatter step.
    """
    dev = rir.device
    fdtype = sample.dtype

    base_idx = torch.floor(sample).to(torch.int64)
    frac = sample - base_idx.to(fdtype)

    n_src, n_mic, nsample = rir.shape
    rows = n_src * n_mic
    base_idx = base_idx.view(rows, -1)
    frac = frac.view(rows, -1)
    gains = amplitude.view(rows, -1)

    half = (fdl - 1) // 2
    tap_grid = _get_fdl_grid(fdl, device=dev, dtype=fdtype)
    tap_offsets = _get_fdl_offsets(fdl, device=dev).view(1, 1, -1)
    window = _get_fdl_window(fdl, device=dev, dtype=fdtype).view(1, 1, -1)
    if use_lut:
        sinc_lut = _get_sinc_lut(fdl, lut_gran, device=dev, dtype=fdtype)
        # LUT index of each tap's integer part, hoisted out of the chunk loop.
        lut_steps = tap_grid.to(torch.int64).view(1, 1, -1) * lut_gran
    else:
        centered = tap_grid.view(1, 1, -1) - half

    # Start offset of each (source, mic) row inside the flattened buffer.
    row_base = (torch.arange(rows, device=dev, dtype=torch.int64) * nsample).view(rows, 1, 1)
    flat_rir = rir.view(-1)

    n_img = base_idx.shape[1]
    for begin in range(0, n_img, chunk_size):
        stop = begin + chunk_size
        idx = base_idx[:, begin:stop]
        amp = gains[:, begin:stop]
        frac_m = frac[:, begin:stop]

        if use_lut:
            # Linear interpolation between neighbouring LUT entries.
            scaled = (1.0 - frac_m) * lut_gran
            coarse = torch.floor(scaled).to(torch.int64)
            resid = (scaled - coarse.to(fdtype)).unsqueeze(-1)
            pos = coarse.unsqueeze(-1) + lut_steps
            lo_val = torch.take(sinc_lut, pos)
            hi_val = torch.take(sinc_lut, pos + 1)
            filt = (lo_val + resid * (hi_val - lo_val)) * window
        else:
            filt = torch.sinc(centered - frac_m.unsqueeze(-1)) * window

        contrib = amp.unsqueeze(-1) * filt
        target = idx.unsqueeze(-1) + tap_offsets
        in_range = (target >= 0) & (target < nsample)
        if in_range.any():
            flat_rir.scatter_add_(0, (target + row_base)[in_range], contrib[in_range])
|
|
627
|
+
|
|
628
|
+
|
|
629
|
+
_SINC_LUT_CACHE: dict[tuple[int, int, str, torch.dtype], Tensor] = {}
|
|
630
|
+
_FDL_GRID_CACHE: dict[tuple[int, str, torch.dtype], Tensor] = {}
|
|
631
|
+
_FDL_OFFSETS_CACHE: dict[tuple[int, str], Tensor] = {}
|
|
632
|
+
_FDL_WINDOW_CACHE: dict[tuple[int, str, torch.dtype], Tensor] = {}
|
|
633
|
+
_ACCUM_BATCH_COMPILED: dict[tuple[str, torch.dtype, int, int, bool, int], callable] = {}
|
|
634
|
+
|
|
635
|
+
|
|
636
|
+
def _get_accumulate_fn(
|
|
637
|
+
cfg: SimulationConfig, device: torch.device, dtype: torch.dtype
|
|
638
|
+
) -> callable:
|
|
639
|
+
"""Return an accumulation function with config-bound constants."""
|
|
640
|
+
use_lut = cfg.use_lut and device.type != "mps"
|
|
641
|
+
fdl = cfg.frac_delay_length
|
|
642
|
+
lut_gran = cfg.sinc_lut_granularity
|
|
643
|
+
chunk_size = cfg.accumulate_chunk_size
|
|
644
|
+
|
|
645
|
+
def _fn(rir: Tensor, sample: Tensor, amplitude: Tensor) -> None:
|
|
646
|
+
_accumulate_rir_batch_impl(
|
|
647
|
+
rir,
|
|
648
|
+
sample,
|
|
649
|
+
amplitude,
|
|
650
|
+
fdl=fdl,
|
|
651
|
+
lut_gran=lut_gran,
|
|
652
|
+
use_lut=use_lut,
|
|
653
|
+
chunk_size=chunk_size,
|
|
654
|
+
)
|
|
655
|
+
|
|
656
|
+
if device.type not in ("cuda", "mps") or not cfg.use_compile:
|
|
657
|
+
return _fn
|
|
658
|
+
key = (str(device), dtype, fdl, lut_gran, use_lut, chunk_size)
|
|
659
|
+
compiled = _ACCUM_BATCH_COMPILED.get(key)
|
|
660
|
+
if compiled is None:
|
|
661
|
+
compiled = torch.compile(_fn, dynamic=True)
|
|
662
|
+
_ACCUM_BATCH_COMPILED[key] = compiled
|
|
663
|
+
return compiled
|
|
664
|
+
|
|
665
|
+
|
|
666
|
+
def _get_fdl_grid(fdl: int, *, device: torch.device, dtype: torch.dtype) -> Tensor:
    """Return the memoized tap index grid ``0, 1, ..., fdl - 1`` as ``dtype``."""
    key = (fdl, str(device), dtype)
    try:
        return _FDL_GRID_CACHE[key]
    except KeyError:
        grid = torch.arange(fdl, device=device, dtype=dtype)
        _FDL_GRID_CACHE[key] = grid
        return grid
|
|
673
|
+
|
|
674
|
+
|
|
675
|
+
def _get_fdl_offsets(fdl: int, *, device: torch.device) -> Tensor:
    """Return memoized int64 tap offsets ``-fdl2, ..., fdl - 1 - fdl2``.

    ``fdl2 = (fdl - 1) // 2`` is the centre tap index, so offset 0 is the
    filter centre.
    """
    key = (fdl, str(device))
    try:
        return _FDL_OFFSETS_CACHE[key]
    except KeyError:
        centre = (fdl - 1) // 2
        offsets = torch.arange(-centre, fdl - centre, device=device, dtype=torch.int64)
        _FDL_OFFSETS_CACHE[key] = offsets
        return offsets
|
|
683
|
+
|
|
684
|
+
|
|
685
|
+
def _get_fdl_window(fdl: int, *, device: torch.device, dtype: torch.dtype) -> Tensor:
    """Return the memoized symmetric (non-periodic) Hann window of length ``fdl``."""
    key = (fdl, str(device), dtype)
    try:
        return _FDL_WINDOW_CACHE[key]
    except KeyError:
        window = torch.hann_window(fdl, periodic=False, device=device, dtype=dtype)
        _FDL_WINDOW_CACHE[key] = window
        return window
|
|
692
|
+
|
|
693
|
+
|
|
694
|
+
def _get_sinc_lut(fdl: int, lut_gran: int, *, device: torch.device, dtype: torch.dtype) -> Tensor:
    """Create (and memoize) a sinc lookup table for fractional delays.

    With ``fdl2 = (fdl - 1) // 2``, entry ``p`` holds
    ``sinc(p / lut_gran - (fdl2 + 1))``. The table therefore starts at
    ``-(fdl2 + 1)`` and is sampled with a step of exactly ``1 / lut_gran``.
    The accumulation kernels index it as
    ``floor((1 - frac) * lut_gran) + k * lut_gran`` (and ``+ 1`` for linear
    interpolation), which relies on that fixed step for every ``fdl``.

    Returns:
        1-D tensor of ``(fdl + 1) * lut_gran + 1`` sinc samples.
    """
    key = (fdl, lut_gran, str(device), dtype)
    cached = _SINC_LUT_CACHE.get(key)
    if cached is not None:
        return cached
    fdl2 = (fdl - 1) // 2
    lut_size = (fdl + 1) * lut_gran + 1
    # Explicit step of 1/lut_gran: a linspace over the symmetric range
    # [-fdl2 - 1, fdl2 + 1] only has that spacing when fdl is odd.
    x = torch.arange(lut_size, device=device, dtype=dtype) / lut_gran - (fdl2 + 1)
    cached = torch.sinc(x)
    _SINC_LUT_CACHE[key] = cached
    return cached
|
|
706
|
+
|
|
707
|
+
|
|
708
|
+
def _apply_diffuse_tail(
    rir: Tensor,
    room: Room,
    beta: Tensor,
    tdiff: float,
    tmax: float,
    *,
    seed: Optional[int] = None,
) -> Tensor:
    """Overwrite the RIR from ``tdiff`` onward with exponentially decaying noise.

    The tail starts at sample ``min(nsample - 1, floor(tdiff * fs))``; if
    that index is not positive, ``rir`` is returned untouched. The envelope is
    ``exp(-t / tau)`` with ``tau = T60 / 6.9078`` (6.9078 is about ln 1000, so
    the amplitude drops by 60 dB over T60). ``T60`` comes from
    ``estimate_t60_from_beta``. An infinite or non-positive T60 falls back to
    ``exp(-3 t)``. The Gaussian noise is drawn from a generator seeded with
    ``seed`` (0 when None). It is scaled by the magnitude of the last sample
    before the tail, plus 1e-8.

    Note: ``tmax`` is accepted but not used; the tail always runs to the end
    of ``rir``.

    Returns:
        ``rir`` itself, modified in place.
    """
    nsample = rir.shape[-1]
    start = min(nsample - 1, int(math.floor(tdiff * room.fs)))
    if start <= 0:
        return rir

    t = torch.arange(nsample - start, device=rir.device, dtype=rir.dtype) / room.fs

    t60 = estimate_t60_from_beta(room.size, beta)
    degenerate = math.isinf(t60) or t60 <= 0
    decay = torch.exp(-t * 3.0) if degenerate else torch.exp(-t / (t60 / 6.9078))

    gen = torch.Generator(device=rir.device)
    gen.manual_seed(seed if seed is not None else 0)
    tail_shape = rir[..., start:].shape
    noise = torch.randn(tail_shape, device=rir.device, dtype=rir.dtype, generator=gen)
    scale = torch.linalg.norm(rir[..., start - 1 : start], dim=-1, keepdim=True) + 1e-8
    rir[..., start:] = noise * decay * scale
    return rir
|