torchrir-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchrir/__init__.py +85 -0
- torchrir/config.py +59 -0
- torchrir/core.py +741 -0
- torchrir/datasets/__init__.py +27 -0
- torchrir/datasets/base.py +27 -0
- torchrir/datasets/cmu_arctic.py +204 -0
- torchrir/datasets/template.py +65 -0
- torchrir/datasets/utils.py +74 -0
- torchrir/directivity.py +33 -0
- torchrir/dynamic.py +60 -0
- torchrir/logging_utils.py +55 -0
- torchrir/plotting.py +210 -0
- torchrir/plotting_utils.py +173 -0
- torchrir/results.py +22 -0
- torchrir/room.py +150 -0
- torchrir/scene.py +67 -0
- torchrir/scene_utils.py +51 -0
- torchrir/signal.py +233 -0
- torchrir/simulators.py +86 -0
- torchrir/utils.py +281 -0
- torchrir-0.1.0.dist-info/METADATA +213 -0
- torchrir-0.1.0.dist-info/RECORD +26 -0
- torchrir-0.1.0.dist-info/WHEEL +5 -0
- torchrir-0.1.0.dist-info/licenses/LICENSE +190 -0
- torchrir-0.1.0.dist-info/licenses/NOTICE +4 -0
- torchrir-0.1.0.dist-info/top_level.txt +1 -0
torchrir/signal.py
ADDED
@@ -0,0 +1,233 @@
"""Signal convolution utilities.

Static convolution helpers are public. Dynamic convolution is exposed via
DynamicConvolver; internal dynamic helpers are prefixed with `_`.
"""

from __future__ import annotations

import math

import torch
from torch import Tensor


def fft_convolve(signal: Tensor, rir: Tensor) -> Tensor:
    """Convolve a 1D signal with a 1D RIR using FFT.

    Args:
        signal: 1D signal tensor.
        rir: 1D impulse response.

    Returns:
        1D tensor of length len(signal) + len(rir) - 1.
    """
    if signal.ndim != 1 or rir.ndim != 1:
        raise ValueError("fft_convolve expects 1D tensors")
    n = signal.numel() + rir.numel() - 1
    fft_len = 1 << (n - 1).bit_length()
    sig_f = torch.fft.rfft(signal, n=fft_len)
    rir_f = torch.fft.rfft(rir, n=fft_len)
    out = torch.fft.irfft(sig_f * rir_f, n=fft_len)
    return out[:n]
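A quick sanity check of fft_convolve against direct convolution (a standalone sketch, not part of the wheel): conv1d computes cross-correlation, so flipping the kernel and using full padding reproduces linear convolution.

import torch
from torchrir.signal import fft_convolve

x = torch.randn(2000, dtype=torch.float64)
h = torch.randn(256, dtype=torch.float64)
y = fft_convolve(x, h)
assert y.shape == (2000 + 256 - 1,)

# Flip the kernel to turn cross-correlation into convolution;
# padding = len(h) - 1 yields the full length len(x) + len(h) - 1.
y_ref = torch.nn.functional.conv1d(
    x.view(1, 1, -1), h.flip(0).view(1, 1, -1), padding=h.numel() - 1
).view(-1)
assert torch.allclose(y, y_ref, atol=1e-8)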

def convolve_rir(signal: Tensor, rirs: Tensor) -> Tensor:
    """Convolve signals with static RIRs (supports multi-source/mic).

    Args:
        signal: (n_src, n_samples) or (n_samples,) tensor.
        rirs: (n_src, n_mic, rir_len) or compatible shape.

    Returns:
        (n_mic, n_samples + rir_len - 1) tensor or 1D for single mic.
    """
    signal = _ensure_signal(signal)
    rirs = _ensure_static_rirs(rirs)
    n_src, n_mic, rir_len = rirs.shape

    if signal.shape[0] not in (1, n_src):
        raise ValueError("signal source count does not match rirs")
    if signal.shape[0] == 1 and n_src > 1:
        signal = signal.expand(n_src, -1)

    out_len = signal.shape[1] + rir_len - 1
    out = torch.zeros((n_mic, out_len), dtype=signal.dtype, device=signal.device)

    for s in range(n_src):
        for m in range(n_mic):
            out[m] += fft_convolve(signal[s], rirs[s, m])

    return out.squeeze(0) if n_mic == 1 else out
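Typical shapes for the public entry point, as a usage sketch (assumes the wheel is installed):

import torch
from torchrir.signal import convolve_rir

n_src, n_mic, rir_len, n_samples = 2, 4, 512, 16000
signal = torch.randn(n_src, n_samples)
rirs = torch.randn(n_src, n_mic, rir_len)

wet = convolve_rir(signal, rirs)
assert wet.shape == (n_mic, n_samples + rir_len - 1)

# A single source/mic collapses to 1D output.
mono = convolve_rir(torch.randn(n_samples), torch.randn(rir_len))
assert mono.shape == (n_samples + rir_len - 1,)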

def _convolve_dynamic_rir_hop(signal: Tensor, rirs: Tensor, hop: int) -> Tensor:
    """Dynamic convolution using fixed hop-size segments."""
    t_steps, n_src, n_mic, rir_len = rirs.shape

    frames = math.ceil(signal.shape[1] / hop)
    frames = min(frames, t_steps)

    out_len = hop * (frames - 1) + hop + rir_len - 1
    out = torch.zeros((n_mic, out_len), dtype=signal.dtype, device=signal.device)

    for t in range(frames):
        start = t * hop
        for s in range(n_src):
            frame = signal[s, start : start + hop]
            if frame.numel() == 0:
                continue
            for m in range(n_mic):
                seg = fft_convolve(frame, rirs[t, s, m])
                out[m, start : start + seg.numel()] += seg

    return out.squeeze(0) if n_mic == 1 else out
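Exercising the hop-based helper directly, purely for illustration (it is private API and normally reached through DynamicConvolver): each of the T RIR snapshots is applied to one hop-length frame, and the tails overlap-add.

import torch
from torchrir.signal import _convolve_dynamic_rir_hop

t_steps, n_src, n_mic, rir_len = 10, 1, 2, 256
sig = torch.randn(n_src, 10 * 400)                  # 10 frames of hop=400
rirs = torch.randn(t_steps, n_src, n_mic, rir_len)  # one RIR per frame

wet = _convolve_dynamic_rir_hop(sig, rirs, hop=400)
assert wet.shape == (n_mic, 10 * 400 + rir_len - 1)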

def _convolve_dynamic_rir_trajectory(
    signal: Tensor,
    rirs: Tensor,
    *,
    timestamps: Tensor | None,
    fs: float | None,
) -> Tensor:
    """Dynamic convolution using variable segments like gpuRIR simulateTrajectory."""
    n_samples = signal.shape[1]
    t_steps, n_src, n_mic, rir_len = rirs.shape

    if timestamps is not None:
        if fs is None:
            raise ValueError("fs must be provided when timestamps are used")
        ts = torch.as_tensor(
            timestamps,
            device=signal.device,
            dtype=torch.float32 if signal.device.type == "mps" else torch.float64,
        )
        if ts.ndim != 1 or ts.numel() != t_steps:
            raise ValueError("timestamps must be 1D and match number of RIR steps")
        if ts[0].item() != 0:
            raise ValueError("first timestamp must be 0")
        w_ini = (ts * fs).to(torch.long)
    else:
        step_fs = n_samples / t_steps
        ts_dtype = torch.float32 if signal.device.type == "mps" else torch.float64
        w_ini = (torch.arange(t_steps, device=signal.device, dtype=ts_dtype) * step_fs).to(
            torch.long
        )

    w_ini = torch.cat(
        [w_ini, torch.tensor([n_samples], device=signal.device, dtype=torch.long)]
    )
    w_len = w_ini[1:] - w_ini[:-1]

    if signal.device.type in ("cuda", "mps"):
        return _convolve_dynamic_rir_trajectory_batched(
            signal, rirs, w_ini=w_ini, w_len=w_len
        )

    max_len = int(w_len.max().item())
    segments = torch.zeros((t_steps, n_src, max_len), dtype=signal.dtype, device=signal.device)
    for t in range(t_steps):
        start = int(w_ini[t].item())
        end = int(w_ini[t + 1].item())
        if end > start:
            segments[t, :, : end - start] = signal[:, start:end]

    out = torch.zeros((n_mic, n_samples + rir_len - 1), dtype=signal.dtype, device=signal.device)

    for t in range(t_steps):
        seg_len = int(w_len[t].item())
        if seg_len == 0:
            continue
        start = int(w_ini[t].item())
        for s in range(n_src):
            frame = segments[t, s, :seg_len]
            for m in range(n_mic):
                conv = fft_convolve(frame, rirs[t, s, m])
                out[m, start : start + seg_len + rir_len - 1] += conv

    return out.squeeze(0) if n_mic == 1 else out
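To make the segmentation above concrete (a standalone sketch, not package code): with fs = 16000 and three RIR snapshots at timestamps [0, 0.5, 1.0] over a 1.6 s signal, the window starts and lengths come out as follows. On CUDA/MPS these same windows are handed to the batched variant below, which convolves chunk_size snapshots per FFT to bound memory.

import torch

fs = 16000.0
n_samples = 25600                                    # 1.6 s at 16 kHz
ts = torch.tensor([0.0, 0.5, 1.0], dtype=torch.float64)

w_ini = (ts * fs).to(torch.long)                     # tensor([    0,  8000, 16000])
w_ini = torch.cat([w_ini, torch.tensor([n_samples])])
w_len = w_ini[1:] - w_ini[:-1]                       # tensor([8000, 8000, 9600])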

def _convolve_dynamic_rir_trajectory_batched(
    signal: Tensor,
    rirs: Tensor,
    *,
    w_ini: Tensor,
    w_len: Tensor,
    chunk_size: int = 8,
) -> Tensor:
    """GPU-friendly batched trajectory convolution using FFT."""
    n_samples = signal.shape[1]
    t_steps, n_src, n_mic, rir_len = rirs.shape
    out = torch.zeros((n_mic, n_samples + rir_len - 1), dtype=signal.dtype, device=signal.device)

    for t0 in range(0, t_steps, chunk_size):
        t1 = min(t0 + chunk_size, t_steps)
        lengths = w_len[t0:t1]
        max_len = int(lengths.max().item())
        if max_len == 0:
            continue
        segments = torch.zeros(
            (t1 - t0, n_src, max_len), dtype=signal.dtype, device=signal.device
        )
        for idx, t in enumerate(range(t0, t1)):
            start = int(w_ini[t].item())
            end = int(w_ini[t + 1].item())
            if end > start:
                segments[idx, :, : end - start] = signal[:, start:end]

        conv_len = max_len + rir_len - 1
        fft_len = 1 << (conv_len - 1).bit_length()
        seg_f = torch.fft.rfft(segments, n=fft_len, dim=-1)
        rir_f = torch.fft.rfft(rirs[t0:t1], n=fft_len, dim=-1)
        conv_out = torch.empty(
            (t1 - t0, n_src, n_mic, fft_len),
            dtype=signal.dtype,
            device=signal.device,
        )
        conv = torch.fft.irfft(seg_f[:, :, None, :] * rir_f, n=fft_len, dim=-1, out=conv_out)
        conv = conv[..., :conv_len]
        conv_sum = conv.sum(dim=1)

        for idx, t in enumerate(range(t0, t1)):
            seg_len = int(lengths[idx].item())
            if seg_len == 0:
                continue
            start = int(w_ini[t].item())
            out[:, start : start + seg_len + rir_len - 1] += conv_sum[
                idx, :, : seg_len + rir_len - 1
            ]

    return out.squeeze(0) if n_mic == 1 else out

def _ensure_signal(signal: Tensor) -> Tensor:
    """Ensure signal has shape (n_src, n_samples)."""
    if signal.ndim == 1:
        return signal.unsqueeze(0)
    if signal.ndim == 2:
        return signal
    raise ValueError("signal must have shape (n_samples,) or (n_src, n_samples)")


def _ensure_static_rirs(rirs: Tensor) -> Tensor:
    """Normalize static RIR shapes to (n_src, n_mic, rir_len)."""
    if rirs.ndim == 1:
        return rirs.view(1, 1, -1)
    if rirs.ndim == 2:
        return rirs.view(1, rirs.shape[0], rirs.shape[1])
    if rirs.ndim == 3:
        return rirs
    raise ValueError(
        "rirs must have shape (rir_len,), (n_mic, rir_len), or (n_src, n_mic, rir_len)"
    )


def _ensure_dynamic_rirs(rirs: Tensor, signal: Tensor) -> Tensor:
    """Normalize dynamic RIR shapes to (T, n_src, n_mic, rir_len)."""
    if rirs.ndim == 2:
        return rirs.view(rirs.shape[0], 1, 1, rirs.shape[1])
    if rirs.ndim == 3:
        if signal.ndim == 2 and rirs.shape[1] == signal.shape[0]:
            return rirs.view(rirs.shape[0], rirs.shape[1], 1, rirs.shape[2])
        return rirs.view(rirs.shape[0], 1, rirs.shape[1], rirs.shape[2])
    if rirs.ndim == 4:
        return rirs
    raise ValueError(
        "rirs must have shape (T, rir_len), (T, n_mic, rir_len), (T, n_src, rir_len), "
        "or (T, n_src, n_mic, rir_len)"
    )
torchrir/simulators.py
ADDED
@@ -0,0 +1,86 @@
"""Simulation strategy interfaces and implementations.

Note:
    RayTracingSimulator and FDTDSimulator are work-in-progress placeholders.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Protocol

from .config import SimulationConfig, default_config
from .core import simulate_dynamic_rir, simulate_rir
from .results import RIRResult
from .scene import Scene


class RIRSimulator(Protocol):
    """Strategy interface for RIR simulation backends."""

    def simulate(self, scene: Scene, config: SimulationConfig | None = None) -> RIRResult:
        """Run a simulation and return the result."""
        ...


@dataclass(frozen=True)
class ISMSimulator:
    """ISM-based simulator using the current core implementation."""

    def simulate(self, scene: Scene, config: SimulationConfig | None = None) -> RIRResult:
        scene.validate()
        cfg = config or default_config()
        if scene.is_dynamic():
            if scene.src_traj is None or scene.mic_traj is None:
                raise ValueError("dynamic scene requires both src_traj and mic_traj")
            rirs = simulate_dynamic_rir(
                room=scene.room,
                src_traj=scene.src_traj,
                mic_traj=scene.mic_traj,
                max_order=None,
                nsample=None,
                tmax=None,
                directivity=None,
                config=cfg,
            )
        else:
            rirs = simulate_rir(
                room=scene.room,
                sources=scene.sources,
                mics=scene.mics,
                max_order=None,
                nsample=None,
                tmax=None,
                directivity=None,
                config=cfg,
            )
        return RIRResult(rirs=rirs, scene=scene, config=cfg, seed=cfg.seed)


@dataclass(frozen=True)
class RayTracingSimulator:
    """Work-in-progress placeholder for ray-tracing simulation.

    Goal:
        Provide a geometric-acoustics backend that traces specular/diffuse
        reflection paths, supports frequency-dependent absorption/scattering,
        and returns a RIRResult compatible with the ISM path. The intent is to
        reuse Scene/SimulationConfig for inputs and keep output shape parity.
    """

    def simulate(self, scene: Scene, config: SimulationConfig | None = None) -> RIRResult:
        raise NotImplementedError("RayTracingSimulator is not implemented yet")


@dataclass(frozen=True)
class FDTDSimulator:
    """Work-in-progress placeholder for FDTD simulation.

    Goal:
        Provide a wave-based solver (finite-difference time-domain) with
        configurable grid resolution, boundary conditions, and stability
        constraints. The solver should target CPU/GPU execution and return
        RIRResult with the same metadata contract as ISM.
    """

    def simulate(self, scene: Scene, config: SimulationConfig | None = None) -> RIRResult:
        raise NotImplementedError("FDTDSimulator is not implemented yet")
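A minimal sketch of how the strategy interface composes (Scene construction lives in torchrir/scene.py, which is not shown in this diff, so it is elided here):

from torchrir.results import RIRResult
from torchrir.scene import Scene
from torchrir.simulators import ISMSimulator, RIRSimulator


def render(sim: RIRSimulator, scene: Scene) -> RIRResult:
    """Accept any backend that satisfies the RIRSimulator Protocol."""
    # Passing config=None lets the simulator fall back to default_config().
    return sim.simulate(scene)


# result = render(ISMSimulator(), scene)   # `scene` built via torchrir.scene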
torchrir/utils.py
ADDED
@@ -0,0 +1,281 @@
"""Utility functions for geometry, acoustics, and tensor handling."""

from __future__ import annotations

import math
import warnings
from dataclasses import dataclass
from typing import Iterable, Optional, Tuple

import torch
from torch import Tensor


_DEF_SPEED_OF_SOUND = 343.0


def as_tensor(
    value: Tensor | Iterable[float] | float | int,
    *,
    device: Optional[torch.device | str] = None,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    """Convert a value to a tensor while preserving device/dtype when possible."""
    if isinstance(device, str):
        device = resolve_device(device)
    if torch.is_tensor(value):
        out = value
        if device is not None:
            out = out.to(device)
        if dtype is not None:
            out = out.to(dtype)
        return out
    return torch.as_tensor(value, device=device, dtype=dtype)


def resolve_device(
    device: Optional[torch.device | str],
    *,
    prefer: Tuple[str, ...] = ("cuda", "mps", "cpu"),
) -> torch.device:
    """Resolve a device string (including 'auto') into a torch.device.

    Falls back to CPU when the requested backend is unavailable.
    """
    if device is None:
        return torch.device("cpu")
    if isinstance(device, torch.device):
        return device

    dev = str(device).lower()
    if dev == "auto":
        for backend in prefer:
            if backend == "cuda" and torch.cuda.is_available():
                return torch.device("cuda")
            if backend == "mps" and torch.backends.mps.is_available():
                return torch.device("mps")
            if backend == "cpu":
                return torch.device("cpu")
        return torch.device("cpu")

    if dev.startswith("cuda"):
        if torch.cuda.is_available():
            return torch.device(device)
        warnings.warn("CUDA not available; falling back to CPU.", RuntimeWarning)
        return torch.device("cpu")
    if dev == "mps":
        if torch.backends.mps.is_available():
            return torch.device("mps")
        warnings.warn("MPS not available; falling back to CPU.", RuntimeWarning)
        return torch.device("cpu")
    if dev == "cpu":
        return torch.device("cpu")

    return torch.device(device)
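A quick illustration of the resolution rules (standalone sketch; results depend on the machine's available backends):

from torchrir.utils import resolve_device

dev = resolve_device("auto")                          # first available of cuda > mps > cpu
cpu_only = resolve_device("auto", prefer=("cpu",))    # always cpu
fallback = resolve_device("cuda:1")                   # cuda:1, or cpu plus a RuntimeWarning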

@dataclass(frozen=True)
class DeviceSpec:
    """Resolve device + dtype defaults consistently."""

    device: Optional[torch.device | str] = None
    dtype: Optional[torch.dtype] = None
    prefer: Tuple[str, ...] = ("cuda", "mps", "cpu")

    def resolve(self, *values) -> Tuple[torch.device, torch.dtype]:
        """Resolve device/dtype from inputs with overrides."""
        tensor_device: Optional[torch.device] = None
        tensor_dtype: Optional[torch.dtype] = None
        for value in values:
            if torch.is_tensor(value):
                if tensor_device is None:
                    tensor_device = value.device
                if tensor_dtype is None:
                    tensor_dtype = value.dtype

        if isinstance(self.device, str) and self.device.lower() == "auto":
            device = tensor_device or resolve_device("auto", prefer=self.prefer)
        elif self.device is None:
            device = tensor_device or torch.device("cpu")
        else:
            device = resolve_device(self.device, prefer=self.prefer)

        if self.dtype is None:
            dtype = tensor_dtype or torch.float32
        else:
            dtype = self.dtype
        return device, dtype


def infer_device_dtype(
    *values,
    device: Optional[torch.device | str] = None,
    dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.device, torch.dtype]:
    """Infer device/dtype from inputs with optional overrides."""
    return DeviceSpec(device=device, dtype=dtype).resolve(*values)


def ensure_dim(size: Tensor) -> Tensor:
    """Validate room size dimensionality (2D or 3D)."""
    if size.ndim != 1 or size.numel() not in (2, 3):
        raise ValueError("room size must be a 1D tensor of length 2 or 3")
    return size


def extend_size(size: Tensor, dim: int) -> Tensor:
    """Extend 2D room size to 3D by adding a dummy z dimension."""
    if size.numel() == dim:
        return size
    if size.numel() == 2 and dim == 3:
        pad = torch.tensor([1.0], device=size.device, dtype=size.dtype)
        return torch.cat([size, pad])
    raise ValueError("unsupported room dimension")


def estimate_beta_from_t60(
    size: Tensor,
    t60: float,
    *,
    device: Optional[torch.device | str] = None,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    """Estimate reflection coefficients from T60 using Sabine's formula."""
    if t60 <= 0:
        raise ValueError("t60 must be positive")
    size = as_tensor(size, device=device, dtype=dtype)
    size = ensure_dim(size)
    dim = size.numel()
    if dim == 2:
        lx, ly = size.tolist()
        lz = 1.0
        volume = lx * ly * lz
        surface = 2.0 * (lx + ly) * lz
        alpha = 0.161 * volume / (t60 * surface)
        alpha = max(0.0, min(alpha, 0.999))
        beta = math.sqrt(1.0 - alpha)
        return torch.full((4,), beta, device=size.device, dtype=size.dtype)
    size = extend_size(size, 3)
    lx, ly, lz = size.tolist()
    volume = lx * ly * lz
    surface = 2.0 * (lx * ly + ly * lz + lx * lz)
    alpha = 0.161 * volume / (t60 * surface)
    alpha = max(0.0, min(alpha, 0.999))
    beta = math.sqrt(1.0 - alpha)
    return torch.full((6,), beta, device=size.device, dtype=size.dtype)

def estimate_t60_from_beta(
    size: Tensor,
    beta: Tensor,
    *,
    device: Optional[torch.device | str] = None,
    dtype: Optional[torch.dtype] = None,
) -> float:
    """Estimate T60 from reflection coefficients using Sabine's formula."""
    size = as_tensor(size, device=device, dtype=dtype)
    size = ensure_dim(size)
    beta = as_tensor(beta, device=size.device, dtype=size.dtype)
    dim = size.numel()
    if dim == 2:
        if beta.numel() != 4:
            raise ValueError("beta must have 4 elements for 2D t60 estimation")
        lx, ly = size.tolist()
        lz = 1.0
        volume = lx * ly * lz
        surfaces = torch.tensor(
            [ly * lz, ly * lz, lx * lz, lx * lz],
            device=size.device,
            dtype=size.dtype,
        )
        alpha = 1.0 - beta**2
        absorption = torch.sum(surfaces * alpha).item()
        if absorption <= 0.0:
            return float("inf")
        return 0.161 * volume / absorption
    size = extend_size(size, 3)
    if beta.numel() != 6:
        raise ValueError("beta must have 6 elements for t60 estimation")
    lx, ly, lz = size.tolist()
    volume = lx * ly * lz
    surfaces = torch.tensor(
        [ly * lz, ly * lz, lx * lz, lx * lz, lx * ly, lx * ly],
        device=size.device,
        dtype=size.dtype,
    )
    alpha = 1.0 - beta**2
    absorption = torch.sum(surfaces * alpha).item()
    if absorption <= 0.0:
        return float("inf")
    return 0.161 * volume / absorption
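A concrete check of the two Sabine helpers (standalone sketch): a 6 x 5 x 3 m room, V = 90 m^3, S = 126 m^2, target T60 = 0.4 s. Because one alpha is broadcast to all walls, the round trip recovers T60 exactly.

import torch
from torchrir.utils import estimate_beta_from_t60, estimate_t60_from_beta

size = torch.tensor([6.0, 5.0, 3.0])
beta = estimate_beta_from_t60(size, t60=0.4)
# alpha = 0.161 * V / (T60 * S) = 0.161 * 90 / (0.4 * 126) ~ 0.2875
# beta  = sqrt(1 - alpha) ~ 0.844, broadcast to all six walls
t60 = estimate_t60_from_beta(size, beta)   # 0.4 by construction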

def orientation_to_unit(orientation: Tensor, dim: int) -> Tensor:
    """Convert orientation representation to unit vectors in 2D/3D."""
    if dim == 2:
        if orientation.ndim == 0:
            angle = orientation
            vec = torch.stack([torch.cos(angle), torch.sin(angle)])
            return normalize_orientation(vec)
        if orientation.shape[-1] == 1:
            angle = orientation.squeeze(-1)
            vec = torch.stack([torch.cos(angle), torch.sin(angle)], dim=-1)
            return normalize_orientation(vec)
        if orientation.ndim == 1 and orientation.numel() != 2:
            angle = orientation
            vec = torch.stack([torch.cos(angle), torch.sin(angle)], dim=-1)
            return normalize_orientation(vec)
        if orientation.shape[-1] == 2:
            return normalize_orientation(orientation)
        raise ValueError("2D orientation must be angle or 2D vector")
    if dim == 3:
        if orientation.shape[-1] == 3:
            return normalize_orientation(orientation)
        if orientation.shape[-1] == 2:
            azimuth = orientation[..., 0]
            elevation = orientation[..., 1]
            x = torch.cos(elevation) * torch.cos(azimuth)
            y = torch.cos(elevation) * torch.sin(azimuth)
            z = torch.sin(elevation)
            vec = torch.stack([x, y, z], dim=-1)
            return normalize_orientation(vec)
        raise ValueError("3D orientation must be vector or (azimuth, elevation)")
    raise ValueError("unsupported dimension for orientation")

def att2t_sabine_estimation(att_db: float, t60: float) -> float:
    """Convert attenuation (dB) to time based on T60."""
    if t60 <= 0:
        raise ValueError("t60 must be positive")
    if att_db <= 0:
        raise ValueError("att_db must be positive")
    return (att_db / 60.0) * t60


def att2t_SabineEstimation(att_db: float, t60: float) -> float:
    """Legacy alias for att2t_sabine_estimation."""
    return att2t_sabine_estimation(att_db, t60)


def beta_SabineEstimation(room_size: Tensor, t60: float) -> Tensor:
    """Legacy alias for estimate_beta_from_t60."""
    return estimate_beta_from_t60(room_size, t60)


def t2n(tmax: float, room_size: Tensor, c: float = _DEF_SPEED_OF_SOUND) -> Tensor:
    """Estimate image counts per dimension needed to cover tmax."""
    if tmax <= 0:
        raise ValueError("tmax must be positive")
    size = as_tensor(room_size)
    size = ensure_dim(size)
    # Number of images in each dimension needed to cover the maximum distance;
    # uses the same heuristic as gpuRIR: n = ceil(tmax * c / room_size)
    n = torch.ceil((tmax * c) / size).to(torch.int64)
    return n
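Worked numbers for t2n (standalone sketch): sound travels tmax * c = 0.5 * 343 = 171.5 m, so each dimension needs enough image copies of the room to span that distance.

import torch
from torchrir.utils import t2n

n = t2n(tmax=0.5, room_size=torch.tensor([5.0, 4.0, 3.0]))
# ceil(171.5 / [5, 4, 3]) = ceil([34.3, 42.875, 57.17]) -> tensor([35, 43, 58])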

def normalize_orientation(orientation: Tensor, *, eps: float = 1e-8) -> Tensor:
    """Normalize orientation vectors with numerical stability."""
    norm = torch.linalg.norm(orientation, dim=-1, keepdim=True)
    norm = torch.clamp(norm, min=eps)
    return orientation / norm