torchrir 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchrir/__init__.py +5 -0
- torchrir/animation.py +172 -0
- torchrir/config.py +11 -2
- torchrir/core.py +30 -0
- torchrir/datasets/cmu_arctic.py +17 -2
- torchrir/datasets/utils.py +20 -2
- torchrir/dynamic.py +11 -2
- torchrir/logging_utils.py +17 -3
- torchrir/metadata.py +216 -0
- torchrir/plotting.py +113 -20
- torchrir/plotting_utils.py +15 -30
- torchrir/results.py +7 -1
- torchrir/room.py +30 -6
- torchrir/scene.py +6 -1
- torchrir/scene_utils.py +22 -4
- torchrir/signal.py +6 -0
- torchrir/simulators.py +5 -1
- torchrir/utils.py +39 -7
- {torchrir-0.1.0.dist-info → torchrir-0.1.2.dist-info}/METADATA +60 -2
- torchrir-0.1.2.dist-info/RECORD +28 -0
- torchrir-0.1.0.dist-info/RECORD +0 -26
- {torchrir-0.1.0.dist-info → torchrir-0.1.2.dist-info}/WHEEL +0 -0
- {torchrir-0.1.0.dist-info → torchrir-0.1.2.dist-info}/licenses/LICENSE +0 -0
- {torchrir-0.1.0.dist-info → torchrir-0.1.2.dist-info}/licenses/NOTICE +0 -0
- {torchrir-0.1.0.dist-info → torchrir-0.1.2.dist-info}/top_level.txt +0 -0
torchrir/__init__.py
CHANGED
@@ -4,6 +4,8 @@ from .config import SimulationConfig, default_config
 from .core import simulate_dynamic_rir, simulate_rir
 from .dynamic import DynamicConvolver
 from .logging_utils import LoggingConfig, get_logger, setup_logging
+from .animation import animate_scene_gif
+from .metadata import build_metadata, save_metadata_json
 from .plotting import plot_scene_dynamic, plot_scene_static
 from .plotting_utils import plot_scene_and_save
 from .room import MicrophoneArray, Room, Source
@@ -61,6 +63,8 @@ __all__ = [
     "get_logger",
     "list_cmu_arctic_speakers",
     "LoggingConfig",
+    "animate_scene_gif",
+    "build_metadata",
     "resolve_device",
     "SentenceLike",
     "load_dataset_sources",
@@ -75,6 +79,7 @@ __all__ = [
     "plot_scene_and_save",
     "plot_scene_static",
     "save_wav",
+    "save_metadata_json",
     "Scene",
     "setup_logging",
     "SimulationConfig",
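Per the export changes above, all three new symbols resolve from the package root in 0.1.2:

from torchrir import animate_scene_gif, build_metadata, save_metadata_json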
torchrir/animation.py
ADDED
@@ -0,0 +1,172 @@
+from __future__ import annotations
+
+"""Animation helpers for dynamic scenes."""
+
+from pathlib import Path
+from typing import Optional, Sequence
+
+import torch
+
+from .plotting_utils import _positions_to_cpu, _to_cpu, _traj_steps, _trajectory_to_cpu
+
+
+def animate_scene_gif(
+    *,
+    out_path: Path,
+    room: Sequence[float] | torch.Tensor,
+    sources: object | torch.Tensor | Sequence,
+    mics: object | torch.Tensor | Sequence,
+    src_traj: Optional[torch.Tensor | Sequence] = None,
+    mic_traj: Optional[torch.Tensor | Sequence] = None,
+    step: int = 1,
+    fps: Optional[float] = None,
+    signal_len: Optional[int] = None,
+    fs: Optional[float] = None,
+    duration_s: Optional[float] = None,
+    plot_2d: bool = True,
+    plot_3d: bool = False,
+) -> Path:
+    """Render a GIF showing source/mic trajectories.
+
+    Args:
+        out_path: Destination GIF path.
+        room: Room size tensor or sequence.
+        sources: Source positions or Source-like object.
+        mics: Microphone positions or MicrophoneArray-like object.
+        src_traj: Optional source trajectory (T, n_src, dim).
+        mic_traj: Optional mic trajectory (T, n_mic, dim).
+        step: Subsampling step for trajectories.
+        fps: Frames per second for the GIF (auto if None).
+        signal_len: Optional signal length (samples) to infer elapsed time.
+        fs: Sample rate used with signal_len.
+        duration_s: Optional total duration in seconds (overrides signal_len/fs).
+        plot_2d: Use 2D projection if True.
+        plot_3d: Use 3D projection if True and dim == 3.
+
+    Returns:
+        The output path.
+
+    Example:
+        >>> animate_scene_gif(
+        ...     out_path=Path("outputs/scene.gif"),
+        ...     room=[6.0, 4.0, 3.0],
+        ...     sources=[[1.0, 2.0, 1.5]],
+        ...     mics=[[2.0, 2.0, 1.5]],
+        ...     src_traj=src_traj,
+        ...     mic_traj=mic_traj,
+        ...     signal_len=16000,
+        ...     fs=16000,
+        ... )
+    """
+    import matplotlib.pyplot as plt
+    from matplotlib import animation
+
+    out_path = Path(out_path)
+    out_path.parent.mkdir(parents=True, exist_ok=True)
+
+    room_size = _to_cpu(room)
+    src_pos = _positions_to_cpu(sources)
+    mic_pos = _positions_to_cpu(mics)
+    dim = int(room_size.numel())
+    view_dim = 3 if (plot_3d and dim == 3) else 2
+    view_room = room_size[:view_dim]
+    view_src = src_pos[:, :view_dim]
+    view_mic = mic_pos[:, :view_dim]
+
+    if src_traj is None and mic_traj is None:
+        raise ValueError("at least one trajectory is required for animation")
+    steps = _traj_steps(src_traj, mic_traj)
+    src_traj = _trajectory_to_cpu(src_traj, src_pos, steps)
+    mic_traj = _trajectory_to_cpu(mic_traj, mic_pos, steps)
+    view_src_traj = src_traj[:, :, :view_dim]
+    view_mic_traj = mic_traj[:, :, :view_dim]
+
+    if view_dim == 3:
+        fig = plt.figure()
+        ax = fig.add_subplot(111, projection="3d")
+        ax.set_xlim(0, view_room[0].item())
+        ax.set_ylim(0, view_room[1].item())
+        ax.set_zlim(0, view_room[2].item())
+        ax.set_xlabel("x")
+        ax.set_ylabel("y")
+        ax.set_zlabel("z")
+    else:
+        fig, ax = plt.subplots()
+        ax.set_xlim(0, view_room[0].item())
+        ax.set_ylim(0, view_room[1].item())
+        ax.set_aspect("equal", adjustable="box")
+        ax.set_xlabel("x")
+        ax.set_ylabel("y")
+
+    src_scatter = ax.scatter([], [], marker="^", color="tab:green", label="sources")
+    mic_scatter = ax.scatter([], [], marker="o", color="tab:orange", label="mics")
+    src_lines = []
+    mic_lines = []
+    for _ in range(view_src_traj.shape[1]):
+        if view_dim == 2:
+            line, = ax.plot([], [], color="tab:green", alpha=0.6)
+        else:
+            line, = ax.plot([], [], [], color="tab:green", alpha=0.6)
+        src_lines.append(line)
+    for _ in range(view_mic_traj.shape[1]):
+        if view_dim == 2:
+            line, = ax.plot([], [], color="tab:orange", alpha=0.6)
+        else:
+            line, = ax.plot([], [], [], color="tab:orange", alpha=0.6)
+        mic_lines.append(line)
+
+    ax.legend(loc="best")
+
+    if duration_s is None and signal_len is not None and fs is not None:
+        duration_s = float(signal_len) / float(fs)
+
+    def _frame(i: int):
+        idx = min(i * step, view_src_traj.shape[0] - 1)
+        src_frame = view_src_traj[: idx + 1]
+        mic_frame = view_mic_traj[: idx + 1]
+        src_pos_frame = view_src_traj[idx]
+        mic_pos_frame = view_mic_traj[idx]
+
+        if view_dim == 2:
+            src_scatter.set_offsets(src_pos_frame)
+            mic_scatter.set_offsets(mic_pos_frame)
+            for s_idx, line in enumerate(src_lines):
+                xy = src_frame[:, s_idx, :]
+                line.set_data(xy[:, 0], xy[:, 1])
+            for m_idx, line in enumerate(mic_lines):
+                xy = mic_frame[:, m_idx, :]
+                line.set_data(xy[:, 0], xy[:, 1])
+        else:
+            src_scatter._offsets3d = (
+                src_pos_frame[:, 0],
+                src_pos_frame[:, 1],
+                src_pos_frame[:, 2],
+            )
+            mic_scatter._offsets3d = (
+                mic_pos_frame[:, 0],
+                mic_pos_frame[:, 1],
+                mic_pos_frame[:, 2],
+            )
+            for s_idx, line in enumerate(src_lines):
+                xyz = src_frame[:, s_idx, :]
+                line.set_data(xyz[:, 0], xyz[:, 1])
+                line.set_3d_properties(xyz[:, 2])
+            for m_idx, line in enumerate(mic_lines):
+                xyz = mic_frame[:, m_idx, :]
+                line.set_data(xyz[:, 0], xyz[:, 1])
+                line.set_3d_properties(xyz[:, 2])
+        if duration_s is not None and steps > 1:
+            t = (idx / (steps - 1)) * duration_s
+            ax.set_title(f"t = {t:.2f} s")
+        return [src_scatter, mic_scatter, *src_lines, *mic_lines]
+
+    frames = max(1, (view_src_traj.shape[0] + step - 1) // step)
+    if fps is None or fps <= 0:
+        if duration_s is not None and duration_s > 0:
+            fps = frames / duration_s
+        else:
+            fps = 6.0
+    anim = animation.FuncAnimation(fig, _frame, frames=frames, interval=1000 / fps, blit=False)
+    anim.save(out_path, writer="pillow", fps=fps)
+    plt.close(fig)
+    return out_path
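The new module renders one GIF frame per (subsampled) trajectory step and derives fps from duration_s, or from signal_len/fs, falling back to 6 fps; saving requires matplotlib with the pillow writer installed. A minimal runnable sketch against the signature above, where the room, positions, and step count are illustrative and not taken from the diff:

from pathlib import Path

import torch

from torchrir import animate_scene_gif

# Illustrative data: one source gliding along x, one fixed mic, 16 steps.
T = 16
alphas = torch.linspace(0.0, 1.0, T)[:, None]
start, end = torch.tensor([1.0, 2.0, 1.5]), torch.tensor([4.0, 2.0, 1.5])
src_traj = ((1 - alphas) * start + alphas * end).unsqueeze(1)  # (T, 1, 3)
mic_traj = torch.tensor([[2.0, 2.0, 1.5]]).expand(T, 1, 3)     # (T, 1, 3)

animate_scene_gif(
    out_path=Path("outputs/scene.gif"),
    room=[6.0, 4.0, 3.0],
    sources=src_traj[0],  # initial positions, shape (n_src, 3)
    mics=mic_traj[0],
    src_traj=src_traj,
    mic_traj=mic_traj,
    signal_len=16000,     # together with fs, pins the animation to 1 s
    fs=16000,
)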
torchrir/config.py
CHANGED
@@ -10,7 +10,12 @@ import torch
 
 @dataclass(frozen=True)
 class SimulationConfig:
-    """Configuration values for RIR simulation and convolution."""
+    """Configuration values for RIR simulation and convolution.
+
+    Example:
+        >>> cfg = SimulationConfig(max_order=6, tmax=0.3, device="auto")
+        >>> cfg.validate()
+    """
 
     fs: Optional[float] = None
     max_order: Optional[int] = None
@@ -53,7 +58,11 @@
 
 
 def default_config() -> SimulationConfig:
-    """Return the default simulation configuration."""
+    """Return the default simulation configuration.
+
+    Example:
+        >>> cfg = default_config()
+    """
     cfg = SimulationConfig()
     cfg.validate()
     return cfg
torchrir/core.py
CHANGED
@@ -58,6 +58,18 @@ def simulate_rir(
 
     Returns:
         Tensor of shape (n_src, n_mic, nsample).
+
+    Example:
+        >>> room = Room.shoebox(size=[6.0, 4.0, 3.0], fs=16000, beta=[0.9] * 6)
+        >>> sources = Source.positions([[1.0, 2.0, 1.5]])
+        >>> mics = MicrophoneArray.positions([[2.0, 2.0, 1.5]])
+        >>> rir = simulate_rir(
+        ...     room=room,
+        ...     sources=sources,
+        ...     mics=mics,
+        ...     max_order=6,
+        ...     tmax=0.3,
+        ... )
     """
     cfg = config or default_config()
     cfg.validate()
@@ -208,6 +220,24 @@ def simulate_dynamic_rir(
 
     Returns:
         Tensor of shape (T, n_src, n_mic, nsample).
+
+    Example:
+        >>> room = Room.shoebox(size=[6.0, 4.0, 3.0], fs=16000, beta=[0.9] * 6)
+        >>> from torchrir import linear_trajectory
+        >>> src_traj = torch.stack(
+        ...     [linear_trajectory(torch.tensor([1.0, 2.0, 1.5]),
+        ...                        torch.tensor([4.0, 2.0, 1.5]), 8)],
+        ...     dim=1,
+        ... )
+        >>> mic_pos = torch.tensor([[2.0, 2.0, 1.5]])
+        >>> mic_traj = mic_pos.unsqueeze(0).repeat(8, 1, 1)
+        >>> rirs = simulate_dynamic_rir(
+        ...     room=room,
+        ...     src_traj=src_traj,
+        ...     mic_traj=mic_traj,
+        ...     max_order=4,
+        ...     tmax=0.3,
+        ... )
     """
     cfg = config or default_config()
     cfg.validate()
torchrir/datasets/cmu_arctic.py
CHANGED
@@ -49,6 +49,13 @@ class CmuArcticSentence:
 
 
 class CmuArcticDataset:
+    """CMU ARCTIC dataset loader.
+
+    Example:
+        >>> dataset = CmuArcticDataset(Path("datasets/cmu_arctic"), speaker="bdl", download=True)
+        >>> audio, fs = dataset.load_wav("arctic_a0001")
+    """
+
     def __init__(self, root: Path, speaker: str = "bdl", download: bool = False) -> None:
         """Initialize a CMU ARCTIC dataset handle.
 
@@ -182,7 +189,11 @@ def _parse_text_line(line: str) -> Tuple[str, str]:
 
 
 def load_wav_mono(path: Path) -> Tuple[torch.Tensor, int]:
-    """Load a wav file and return mono audio and sample rate."""
+    """Load a wav file and return mono audio and sample rate.
+
+    Example:
+        >>> audio, fs = load_wav_mono(Path("datasets/cmu_arctic/ARCTIC/.../wav/arctic_a0001.wav"))
+    """
     import soundfile as sf
 
     audio, sample_rate = sf.read(str(path), dtype="float32", always_2d=True)
@@ -195,7 +206,11 @@ def load_wav_mono(path: Path) -> Tuple[torch.Tensor, int]:
 
 
 def save_wav(path: Path, audio: torch.Tensor, sample_rate: int) -> None:
-    """Save a mono or multi-channel wav to disk."""
+    """Save a mono or multi-channel wav to disk.
+
+    Example:
+        >>> save_wav(Path("outputs/example.wav"), audio, sample_rate)
+    """
     import soundfile as sf
 
     audio = audio.detach().cpu().clamp(-1.0, 1.0).to(torch.float32)
torchrir/datasets/utils.py
CHANGED
@@ -11,7 +11,12 @@ from .base import BaseDataset, SentenceLike
 
 
 def choose_speakers(dataset: BaseDataset, num_sources: int, rng: random.Random) -> List[str]:
-    """Select unique speakers for the requested number of sources."""
+    """Select unique speakers for the requested number of sources.
+
+    Example:
+        >>> rng = random.Random(0)
+        >>> speakers = choose_speakers(dataset, num_sources=2, rng=rng)
+    """
     speakers = dataset.list_speakers()
     if not speakers:
         raise RuntimeError("no speakers available")
@@ -27,7 +32,20 @@ def load_dataset_sources(
     duration_s: float,
     rng: random.Random,
 ) -> Tuple[torch.Tensor, int, List[Tuple[str, List[str]]]]:
-    """Load and concatenate utterances for each speaker into fixed-length signals."""
+    """Load and concatenate utterances for each speaker into fixed-length signals.
+
+    Example:
+        >>> from pathlib import Path
+        >>> from torchrir import CmuArcticDataset
+        >>> rng = random.Random(0)
+        >>> root = Path("datasets/cmu_arctic")
+        >>> signals, fs, info = load_dataset_sources(
+        ...     dataset_factory=lambda spk: CmuArcticDataset(root, speaker=spk, download=True),
+        ...     num_sources=2,
+        ...     duration_s=10.0,
+        ...     rng=rng,
+        ... )
+    """
     dataset0 = dataset_factory(None)
     speakers = choose_speakers(dataset0, num_sources, rng)
     signals: List[torch.Tensor] = []
torchrir/dynamic.py
CHANGED
@@ -17,7 +17,12 @@ from .signal import _ensure_dynamic_rirs, _ensure_signal
 
 @dataclass(frozen=True)
 class DynamicConvolver:
-    """Convolver for time-varying RIRs."""
+    """Convolver for time-varying RIRs.
+
+    Example:
+        >>> convolver = DynamicConvolver(mode="trajectory")
+        >>> y = convolver.convolve(signal, rirs)
+    """
 
     mode: str = "trajectory"
     hop: Optional[int] = None
@@ -28,7 +33,11 @@
         return self.convolve(signal, rirs)
 
     def convolve(self, signal: Tensor, rirs: Tensor) -> Tensor:
-        """Convolve signals with time-varying RIRs."""
+        """Convolve signals with time-varying RIRs.
+
+        Example:
+            >>> y = DynamicConvolver(mode="hop", hop=1024).convolve(signal, rirs)
+        """
         if self.mode not in ("trajectory", "hop"):
            raise ValueError("mode must be 'trajectory' or 'hop'")
        if self.mode == "hop":
torchrir/logging_utils.py
CHANGED
@@ -9,7 +9,12 @@ from typing import Optional
 
 @dataclass(frozen=True)
 class LoggingConfig:
-    """Configuration for torchrir logging."""
+    """Configuration for torchrir logging.
+
+    Example:
+        >>> config = LoggingConfig(level="INFO")
+        >>> logger = setup_logging(config)
+    """
 
     level: str | int = "INFO"
     format: str = "%(levelname)s:%(name)s:%(message)s"
@@ -33,7 +38,12 @@
 
 
 def setup_logging(config: LoggingConfig, *, name: str = "torchrir") -> logging.Logger:
-    """Configure and return the base torchrir logger."""
+    """Configure and return the base torchrir logger.
+
+    Example:
+        >>> logger = setup_logging(LoggingConfig(level="DEBUG"))
+        >>> logger.info("ready")
+    """
     logger = logging.getLogger(name)
     level = config.resolve_level()
     logger.setLevel(level)
@@ -47,7 +57,11 @@ def setup_logging(config: LoggingConfig, *, name: str = "torchrir") -> logging.Logger:
 
 
 def get_logger(name: Optional[str] = None) -> logging.Logger:
-    """Return a torchrir logger, namespaced under the torchrir root."""
+    """Return a torchrir logger, namespaced under the torchrir root.
+
+    Example:
+        >>> logger = get_logger("examples.static")
+    """
     if not name:
         return logging.getLogger("torchrir")
     if name.startswith("torchrir"):
torchrir/metadata.py
ADDED
@@ -0,0 +1,216 @@
+from __future__ import annotations
+
+"""Metadata helpers for simulation outputs."""
+
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import json
+import torch
+from torch import Tensor
+
+from .room import MicrophoneArray, Room, Source
+
+
+@dataclass(frozen=True)
+class ArrayAttributes:
+    """Structured description of a microphone array."""
+
+    geometry_name: str
+    positions: Tensor
+    orientation: Optional[Tensor]
+    center: Tensor
+    normal: Optional[Tensor]
+    spacing: Optional[float]
+
+
+def build_metadata(
+    *,
+    room: Room,
+    sources: Source,
+    mics: MicrophoneArray,
+    rirs: Tensor,
+    src_traj: Optional[Tensor] = None,
+    mic_traj: Optional[Tensor] = None,
+    timestamps: Optional[Tensor] = None,
+    signal_len: Optional[int] = None,
+    source_info: Optional[Any] = None,
+    extra: Optional[Dict[str, Any]] = None,
+) -> Dict[str, Any]:
+    """Build JSON-serializable metadata for a simulation output.
+
+    Example:
+        >>> metadata = build_metadata(
+        ...     room=room,
+        ...     sources=sources,
+        ...     mics=mics,
+        ...     rirs=rirs,
+        ...     src_traj=src_traj,
+        ...     mic_traj=mic_traj,
+        ...     signal_len=signal.shape[-1],
+        ... )
+        >>> save_metadata_json(Path("outputs/scene_metadata.json"), metadata)
+    """
+    nsample = int(rirs.shape[-1])
+    fs = float(room.fs)
+    time_axis = {
+        "fs": fs,
+        "nsample": nsample,
+        "t": _to_serializable(torch.arange(nsample, dtype=torch.float32) / fs),
+    }
+
+    src_pos = sources.positions
+    mic_pos = mics.positions
+    dim = int(room.size.numel())
+    src_traj_n = _normalize_traj(src_traj, src_pos, dim, "src_traj")
+    mic_traj_n = _normalize_traj(mic_traj, mic_pos, dim, "mic_traj")
+
+    t_steps = max(src_traj_n.shape[0], mic_traj_n.shape[0])
+    if src_traj_n.shape[0] == 1 and t_steps > 1:
+        src_traj_n = src_traj_n.expand(t_steps, -1, -1)
+    if mic_traj_n.shape[0] == 1 and t_steps > 1:
+        mic_traj_n = mic_traj_n.expand(t_steps, -1, -1)
+    if src_traj_n.shape[0] != mic_traj_n.shape[0]:
+        raise ValueError("src_traj and mic_traj must have matching time steps")
+
+    azimuth, elevation = _compute_doa(src_traj_n, mic_traj_n)
+    doa = {
+        "frame": "world",
+        "unit": "radians",
+        "azimuth": _to_serializable(azimuth),
+        "elevation": _to_serializable(elevation),
+    }
+
+    timestamps_out: Optional[Tensor] = None
+    if timestamps is not None:
+        timestamps_out = timestamps
+    elif t_steps > 1 and signal_len is not None:
+        duration = max(0.0, (float(signal_len) - 1.0) / fs)
+        timestamps_out = torch.linspace(0.0, duration, t_steps, dtype=torch.float32)
+
+    array_attrs = _array_attributes(mics)
+
+    metadata: Dict[str, Any] = {
+        "room": {
+            "size": _to_serializable(room.size),
+            "c": float(room.c),
+            "beta": _to_serializable(room.beta) if room.beta is not None else None,
+            "t60": float(room.t60) if room.t60 is not None else None,
+            "fs": fs,
+        },
+        "sources": {
+            "positions": _to_serializable(src_pos),
+            "orientation": _to_serializable(sources.orientation),
+        },
+        "mics": {
+            "positions": _to_serializable(mic_pos),
+            "orientation": _to_serializable(mics.orientation),
+        },
+        "trajectories": {
+            "sources": _to_serializable(src_traj_n if t_steps > 1 else None),
+            "mics": _to_serializable(mic_traj_n if t_steps > 1 else None),
+        },
+        "array": {
+            "geometry": array_attrs.geometry_name,
+            "positions": _to_serializable(array_attrs.positions),
+            "orientation": _to_serializable(array_attrs.orientation),
+            "center": _to_serializable(array_attrs.center),
+            "normal": _to_serializable(array_attrs.normal),
+            "spacing": array_attrs.spacing,
+        },
+        "time_axis": time_axis,
+        "doa": doa,
+        "timestamps": _to_serializable(timestamps_out),
+        "rirs_shape": list(rirs.shape),
+        "dynamic": bool(t_steps > 1),
+    }
+
+    if source_info is not None:
+        metadata["source_info"] = _to_serializable(source_info)
+    if extra:
+        metadata["extra"] = _to_serializable(extra)
+    return metadata
+
+
+def save_metadata_json(path: Path, metadata: Dict[str, Any]) -> None:
+    """Save metadata as JSON to the given path.
+
+    Example:
+        >>> save_metadata_json(Path("outputs/scene_metadata.json"), metadata)
+    """
+    path.parent.mkdir(parents=True, exist_ok=True)
+    with path.open("w", encoding="utf-8") as f:
+        json.dump(metadata, f, indent=2)
+
+
+def _normalize_traj(traj: Optional[Tensor], pos: Tensor, dim: int, name: str) -> Tensor:
+    if traj is None:
+        if pos.ndim != 2 or pos.shape[1] != dim:
+            raise ValueError(f"{name} default positions must have shape (N, {dim})")
+        return pos.unsqueeze(0)
+    if not torch.is_tensor(traj):
+        raise TypeError(f"{name} must be a Tensor")
+    if traj.ndim == 2:
+        if traj.shape[1] != dim:
+            raise ValueError(f"{name} must have shape (T, {dim})")
+        return traj.unsqueeze(1)
+    if traj.ndim == 3:
+        if traj.shape[2] != dim:
+            raise ValueError(f"{name} must have shape (T, N, {dim})")
+        return traj
+    raise ValueError(f"{name} must have shape (T, N, {dim})")
+
+
+def _compute_doa(src_traj: Tensor, mic_traj: Tensor) -> tuple[Tensor, Tensor]:
+    vec = src_traj[:, :, None, :] - mic_traj[:, None, :, :]
+    x = vec[..., 0]
+    y = vec[..., 1]
+    azimuth = torch.atan2(y, x)
+    if vec.shape[-1] < 3:
+        elevation = torch.zeros_like(azimuth)
+    else:
+        z = vec[..., 2]
+        r_xy = torch.sqrt(x**2 + y**2)
+        elevation = torch.atan2(z, r_xy)
+    return azimuth, elevation
+
+
+def _array_attributes(mics: MicrophoneArray) -> ArrayAttributes:
+    pos = mics.positions
+    n_mic = pos.shape[0]
+    if n_mic == 1:
+        geometry = "single"
+    elif n_mic == 2:
+        geometry = "binaural"
+    else:
+        geometry = "custom"
+    center = pos.mean(dim=0)
+    spacing = None
+    if n_mic >= 2:
+        dists = torch.cdist(pos, pos)
+        dists = dists[dists > 0]
+        if dists.numel() > 0:
+            spacing = float(dists.min().item())
+    return ArrayAttributes(
+        geometry_name=geometry,
+        positions=pos,
+        orientation=mics.orientation,
+        center=center,
+        normal=None,
+        spacing=spacing,
+    )
+
+
+def _to_serializable(value: Any) -> Any:
+    if value is None:
+        return None
+    if torch.is_tensor(value):
+        return value.detach().cpu().tolist()
+    if isinstance(value, Path):
+        return str(value)
+    if isinstance(value, dict):
+        return {k: _to_serializable(v) for k, v in value.items()}
+    if isinstance(value, (list, tuple)):
+        return [_to_serializable(v) for v in value]
+    return value
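Note the DOA convention in _compute_doa: world-frame azimuth = atan2(y, x) and elevation = atan2(z, sqrt(x^2 + y^2)), in radians, for every (step, source, mic) pair. A sketch of the static round trip, reusing the constructors from the core.py docstring examples above (their exact signatures are taken on trust from those examples, not verified against 0.1.2):

from pathlib import Path

from torchrir import (
    MicrophoneArray,
    Room,
    Source,
    build_metadata,
    save_metadata_json,
    simulate_rir,
)

room = Room.shoebox(size=[6.0, 4.0, 3.0], fs=16000, beta=[0.9] * 6)
sources = Source.positions([[1.0, 2.0, 1.5]])
mics = MicrophoneArray.positions([[2.0, 2.0, 1.5]])
rirs = simulate_rir(room=room, sources=sources, mics=mics, max_order=6, tmax=0.3)

# Without trajectories the positions count as a single time step,
# so metadata["dynamic"] is False and no timestamps are emitted.
metadata = build_metadata(room=room, sources=sources, mics=mics, rirs=rirs)
save_metadata_json(Path("outputs/scene_metadata.json"), metadata)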