torchrir 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- torchrir/__init__.py +5 -0
- torchrir/animation.py +172 -0
- torchrir/config.py +11 -2
- torchrir/core.py +30 -0
- torchrir/datasets/cmu_arctic.py +17 -2
- torchrir/datasets/utils.py +20 -2
- torchrir/dynamic.py +11 -2
- torchrir/logging_utils.py +17 -3
- torchrir/metadata.py +216 -0
- torchrir/plotting.py +113 -20
- torchrir/plotting_utils.py +15 -30
- torchrir/results.py +7 -1
- torchrir/room.py +30 -6
- torchrir/scene.py +6 -1
- torchrir/scene_utils.py +22 -4
- torchrir/signal.py +6 -0
- torchrir/simulators.py +5 -1
- torchrir/utils.py +39 -7
- {torchrir-0.1.0.dist-info → torchrir-0.1.2.dist-info}/METADATA +60 -2
- torchrir-0.1.2.dist-info/RECORD +28 -0
- torchrir-0.1.0.dist-info/RECORD +0 -26
- {torchrir-0.1.0.dist-info → torchrir-0.1.2.dist-info}/WHEEL +0 -0
- {torchrir-0.1.0.dist-info → torchrir-0.1.2.dist-info}/licenses/LICENSE +0 -0
- {torchrir-0.1.0.dist-info → torchrir-0.1.2.dist-info}/licenses/NOTICE +0 -0
- {torchrir-0.1.0.dist-info → torchrir-0.1.2.dist-info}/top_level.txt +0 -0
torchrir/plotting.py
CHANGED

@@ -20,7 +20,15 @@ def plot_scene_static(
     title: Optional[str] = None,
     show: bool = False,
 ):
-    """Plot a static room with source and mic positions."""
+    """Plot a static room with source and mic positions.
+
+    Example:
+        >>> ax = plot_scene_static(
+        ...     room=[6.0, 4.0, 3.0],
+        ...     sources=[[1.0, 2.0, 1.5]],
+        ...     mics=[[2.0, 2.0, 1.5]],
+        ... )
+    """
     plt, ax = _setup_axes(ax, room)

     size = _room_size(room, ax)
@@ -46,21 +54,35 @@ def plot_scene_dynamic(
     src_traj: Tensor | Sequence,
     mic_traj: Tensor | Sequence,
     step: int = 1,
+    src_pos: Optional[Tensor | Sequence] = None,
+    mic_pos: Optional[Tensor | Sequence] = None,
     ax: Any | None = None,
     title: Optional[str] = None,
     show: bool = False,
 ):
-    """Plot source and mic trajectories within a room."""
+    """Plot source and mic trajectories within a room.
+
+    If trajectories are static, only positions are plotted.
+
+    Example:
+        >>> ax = plot_scene_dynamic(
+        ...     room=[6.0, 4.0, 3.0],
+        ...     src_traj=src_traj,
+        ...     mic_traj=mic_traj,
+        ... )
+    """
     plt, ax = _setup_axes(ax, room)

     size = _room_size(room, ax)
     _draw_room(ax, size)

-    src_traj = _as_trajectory(src_traj, ax)
-    mic_traj = _as_trajectory(mic_traj, ax)
+    src_traj = _as_trajectory(src_traj)
+    mic_traj = _as_trajectory(mic_traj)
+    src_pos_t = _extract_positions(src_pos, ax) if src_pos is not None else src_traj[0]
+    mic_pos_t = _extract_positions(mic_pos, ax) if mic_pos is not None else mic_traj[0]

-
-
+    _plot_entity(ax, src_traj, src_pos_t, step=step, label="sources", marker="^")
+    _plot_entity(ax, mic_traj, mic_pos_t, step=step, label="mics", marker="o")

     if title:
         ax.set_title(title)
@@ -176,18 +198,32 @@ def _extract_positions(entity: Source | MicrophoneArray | Tensor | Sequence, ax: Any | None) -> Tensor:
     return pos


-def _scatter_positions(ax: Any, positions: Tensor, *, label: str, marker: str) -> None:
+def _scatter_positions(
+    ax: Any,
+    positions: Tensor,
+    *,
+    label: str,
+    marker: str,
+    color: Optional[str] = None,
+) -> None:
     """Scatter-plot positions in 2D or 3D."""
     if positions.numel() == 0:
         return
     dim = positions.shape[1]
     if dim == 2:
-        ax.scatter(positions[:, 0], positions[:, 1], label=label, marker=marker)
+        ax.scatter(positions[:, 0], positions[:, 1], label=label, marker=marker, color=color)
     else:
-        ax.scatter(
+        ax.scatter(
+            positions[:, 0],
+            positions[:, 1],
+            positions[:, 2],
+            label=label,
+            marker=marker,
+            color=color,
+        )


-def _as_trajectory(traj: Tensor | Sequence, ax: Any | None) -> Tensor:
+def _as_trajectory(traj: Tensor | Sequence) -> Tensor:
     """Validate and normalize a trajectory tensor."""
     traj = as_tensor(traj)
     if traj.ndim != 3:
@@ -195,16 +231,73 @@ def _as_trajectory(traj: Tensor | Sequence, ax: Any | None) -> Tensor:
     return traj


-def
-
+def _plot_entity(
+    ax: Any,
+    traj: Tensor,
+    positions: Tensor,
+    *,
+    step: int,
+    label: str,
+    marker: str,
+) -> None:
+    """Plot trajectories and/or static positions with a unified legend entry."""
     if traj.numel() == 0:
         return
-
-
-
-
-
+    import matplotlib.pyplot as plt
+
+    if positions.shape != traj.shape[1:]:
+        positions = traj[0]
+    moving = _is_moving(traj, positions)
+    colors = plt.rcParams.get("axes.prop_cycle", None)
+    if colors is not None:
+        palette = colors.by_key().get("color", [])
     else:
-
-
-
+        palette = []
+    if not palette:
+        palette = ["C0", "C1", "C2", "C3", "C4", "C5"]
+
+    dim = traj.shape[2]
+    for idx in range(traj.shape[1]):
+        color = palette[idx % len(palette)]
+        lbl = label if idx == 0 else "_nolegend_"
+        if moving:
+            if dim == 2:
+                xy = traj[::step, idx]
+                ax.plot(
+                    xy[:, 0],
+                    xy[:, 1],
+                    label=lbl,
+                    color=color,
+                    marker=marker,
+                    markevery=[0],
+                )
+            else:
+                xyz = traj[::step, idx]
+                ax.plot(
+                    xyz[:, 0],
+                    xyz[:, 1],
+                    xyz[:, 2],
+                    label=lbl,
+                    color=color,
+                    marker=marker,
+                    markevery=[0],
+                )
+        pos = positions[idx : idx + 1]
+        _scatter_positions(ax, pos, label="_nolegend_", marker=marker, color=color)
+    if not moving:
+        # ensure legend marker uses the group label
+        _scatter_positions(
+            ax,
+            positions[:1],
+            label=label,
+            marker=marker,
+            color=palette[0],
+        )
+
+
+def _is_moving(traj: Tensor, positions: Tensor, *, tol: float = 1e-6) -> bool:
+    """Return True if any trajectory deviates from the provided positions."""
+    if traj.numel() == 0:
+        return False
+    pos0 = positions.unsqueeze(0).expand_as(traj)
+    return torch.any(torch.linalg.norm(traj - pos0, dim=-1) > tol).item()
torchrir/plotting_utils.py
CHANGED

@@ -26,6 +26,8 @@ def plot_scene_and_save(
 ) -> tuple[list[Path], list[Path]]:
     """Plot static and dynamic scenes and save images to disk.

+    Dynamic plots show trajectories for moving entities and points for fixed ones.
+
     Args:
         out_dir: Output directory for PNGs.
         room: Room size tensor or sequence.
@@ -41,6 +43,17 @@ def plot_scene_and_save(

     Returns:
         Tuple of (static_paths, dynamic_paths).
+
+    Example:
+        >>> plot_scene_and_save(
+        ...     out_dir=Path("outputs"),
+        ...     room=[6.0, 4.0, 3.0],
+        ...     sources=[[1.0, 2.0, 1.5]],
+        ...     mics=[[2.0, 2.0, 1.5]],
+        ...     src_traj=src_traj,
+        ...     mic_traj=mic_traj,
+        ...     prefix="scene",
+        ... )
     """
     out_dir = Path(out_dir)
     out_dir.mkdir(parents=True, exist_ok=True)
@@ -85,11 +98,12 @@ def plot_scene_and_save(
         room=view_room,
         src_traj=view_src_traj,
         mic_traj=view_mic_traj,
+        src_pos=view_src,
+        mic_pos=view_mic,
         step=step,
         title=f"Room scene ({view_dim}D trajectories)",
         show=False,
     )
-    _overlay_positions(ax, view_src, view_mic)
     dynamic_path = out_dir / f"{prefix}_dynamic_{view_dim}d.png"
     _save_axes(ax, dynamic_path, show=show)
     dynamic_paths.append(dynamic_path)
@@ -142,32 +156,3 @@ def _save_axes(ax: Any, path: Path, *, show: bool) -> None:
     if show:
         plt.show()
     plt.close(fig)
-
-
-def _overlay_positions(ax: Any, sources: torch.Tensor, mics: torch.Tensor) -> None:
-    """Overlay static source and mic positions on an axis."""
-    if sources.numel() > 0:
-        if sources.shape[1] == 2:
-            ax.scatter(sources[:, 0], sources[:, 1], marker="^", label="sources", color="tab:green")
-        else:
-            ax.scatter(
-                sources[:, 0],
-                sources[:, 1],
-                sources[:, 2],
-                marker="^",
-                label="sources",
-                color="tab:green",
-            )
-    if mics.numel() > 0:
-        if mics.shape[1] == 2:
-            ax.scatter(mics[:, 0], mics[:, 1], marker="o", label="mics", color="tab:orange")
-        else:
-            ax.scatter(
-                mics[:, 0],
-                mics[:, 1],
-                mics[:, 2],
-                marker="o",
-                label="mics",
-                color="tab:orange",
-            )
-    ax.legend(loc="best")
torchrir/results.py
CHANGED

@@ -13,7 +13,13 @@ from .scene import Scene

 @dataclass(frozen=True)
 class RIRResult:
-    """Container for RIRs with metadata."""
+    """Container for RIRs with metadata.
+
+    Example:
+        >>> from torchrir import ISMSimulator
+        >>> result = ISMSimulator().simulate(scene, config)
+        >>> rirs = result.rirs
+    """

     rirs: Tensor
     scene: Scene
torchrir/room.py
CHANGED

@@ -13,7 +13,11 @@ from .utils import as_tensor, ensure_dim

 @dataclass(frozen=True)
 class Room:
-    """Room geometry and acoustic parameters."""
+    """Room geometry and acoustic parameters.
+
+    Example:
+        >>> room = Room.shoebox(size=[6.0, 4.0, 3.0], fs=16000, beta=[0.9] * 6)
+    """

     size: Tensor
     fs: float
@@ -43,7 +47,11 @@ class Room:
         device: Optional[torch.device | str] = None,
         dtype: Optional[torch.dtype] = None,
     ) -> "Room":
-        """Create a rectangular (shoebox) room."""
+        """Create a rectangular (shoebox) room.
+
+        Example:
+            >>> room = Room.shoebox(size=[6.0, 4.0, 3.0], fs=16000, beta=[0.9] * 6)
+        """
         size_t = as_tensor(size, device=device, dtype=dtype)
         size_t = ensure_dim(size_t)
         beta_t = None
@@ -54,7 +62,11 @@ class Room:

 @dataclass(frozen=True)
 class Source:
-    """Source container with positions and optional orientation."""
+    """Source container with positions and optional orientation.
+
+    Example:
+        >>> sources = Source.positions([[1.0, 2.0, 1.5]])
+    """

     positions: Tensor
     orientation: Optional[Tensor] = None
@@ -79,7 +91,11 @@ class Source:
         device: Optional[torch.device | str] = None,
         dtype: Optional[torch.dtype] = None,
     ) -> "Source":
-        """Construct a Source from positions."""
+        """Construct a Source from positions.
+
+        Example:
+            >>> sources = Source.positions([[1.0, 2.0, 1.5]])
+        """
         return cls.from_positions(
             positions, orientation=orientation, device=device, dtype=dtype
         )
@@ -103,7 +119,11 @@ class Source:

 @dataclass(frozen=True)
 class MicrophoneArray:
-    """Microphone array container."""
+    """Microphone array container.
+
+    Example:
+        >>> mics = MicrophoneArray.positions([[2.0, 2.0, 1.5]])
+    """

     positions: Tensor
     orientation: Optional[Tensor] = None
@@ -128,7 +148,11 @@ class MicrophoneArray:
         device: Optional[torch.device | str] = None,
         dtype: Optional[torch.dtype] = None,
     ) -> "MicrophoneArray":
-        """Construct a MicrophoneArray from positions."""
+        """Construct a MicrophoneArray from positions.
+
+        Example:
+            >>> mics = MicrophoneArray.positions([[2.0, 2.0, 1.5]])
+        """
         return cls.from_positions(
             positions, orientation=orientation, device=device, dtype=dtype
         )
torchrir/scene.py
CHANGED

@@ -13,7 +13,12 @@ from .room import MicrophoneArray, Room, Source

 @dataclass(frozen=True)
 class Scene:
-    """Container for room, sources, microphones, and optional trajectories."""
+    """Container for room, sources, microphones, and optional trajectories.
+
+    Example:
+        >>> scene = Scene(room=room, sources=sources, mics=mics, src_traj=src_traj, mic_traj=mic_traj)
+        >>> scene.validate()
+    """

     room: Room
     sources: Source
torchrir/scene_utils.py
CHANGED

@@ -15,7 +15,13 @@ def sample_positions(
     rng: random.Random,
     margin: float = 0.5,
 ) -> torch.Tensor:
-    """Sample random positions within a room with a safety margin."""
+    """Sample random positions within a room with a safety margin.
+
+    Example:
+        >>> rng = random.Random(0)
+        >>> room = torch.tensor([6.0, 4.0, 3.0])
+        >>> positions = sample_positions(num=2, room_size=room, rng=rng)
+    """
     dim = room_size.numel()
     low = [margin] * dim
     high = [float(room_size[i].item()) - margin for i in range(dim)]
@@ -27,7 +33,11 @@ def sample_positions(


 def linear_trajectory(start: torch.Tensor, end: torch.Tensor, steps: int) -> torch.Tensor:
-    """Create a linear trajectory between start and end."""
+    """Create a linear trajectory between start and end.
+
+    Example:
+        >>> traj = linear_trajectory(torch.tensor([1.0, 1.0, 1.0]), torch.tensor([4.0, 2.0, 1.0]), 8)
+    """
     return torch.stack(
         [start + (end - start) * t / (steps - 1) for t in range(steps)],
         dim=0,
@@ -35,7 +45,11 @@ def linear_trajectory(start: torch.Tensor, end: torch.Tensor, steps: int) -> torch.Tensor:


 def binaural_mic_positions(center: torch.Tensor, offset: float = 0.08) -> torch.Tensor:
-    """Create a two-mic binaural layout around a center point."""
+    """Create a two-mic binaural layout around a center point.
+
+    Example:
+        >>> mics = binaural_mic_positions(torch.tensor([2.0, 2.0, 1.5]))
+    """
     dim = center.numel()
     offset_vec = torch.zeros((dim,), dtype=torch.float32)
     offset_vec[0] = offset
@@ -45,7 +59,11 @@ def binaural_mic_positions(center: torch.Tensor, offset: float = 0.08) -> torch.Tensor:


 def clamp_positions(positions: torch.Tensor, room_size: torch.Tensor, margin: float = 0.1) -> torch.Tensor:
-    """Clamp positions to remain inside the room with a margin."""
+    """Clamp positions to remain inside the room with a margin.
+
+    Example:
+        >>> clamped = clamp_positions(positions, torch.tensor([6.0, 4.0, 3.0]))
+    """
     min_v = torch.full_like(room_size, margin)
     max_v = room_size - margin
     return torch.max(torch.min(positions, max_v), min_v)
torchrir/signal.py
CHANGED

@@ -21,6 +21,9 @@ def fft_convolve(signal: Tensor, rir: Tensor) -> Tensor:

     Returns:
         1D tensor of length len(signal) + len(rir) - 1.
+
+    Example:
+        >>> y = fft_convolve(signal, rir)
     """
     if signal.ndim != 1 or rir.ndim != 1:
         raise ValueError("fft_convolve expects 1D tensors")
@@ -41,6 +44,9 @@ def convolve_rir(signal: Tensor, rirs: Tensor) -> Tensor:

     Returns:
         (n_mic, n_samples + rir_len - 1) tensor or 1D for single mic.
+
+    Example:
+        >>> y = convolve_rir(signal, rirs)
     """
     signal = _ensure_signal(signal)
     rirs = _ensure_static_rirs(rirs)
torchrir/simulators.py
CHANGED

@@ -24,7 +24,11 @@ class RIRSimulator(Protocol):

 @dataclass(frozen=True)
 class ISMSimulator:
-    """ISM-based simulator using the current core implementation."""
+    """ISM-based simulator using the current core implementation.
+
+    Example:
+        >>> result = ISMSimulator().simulate(scene, config)
+    """

     def simulate(self, scene: Scene, config: SimulationConfig | None = None) -> RIRResult:
         scene.validate()
torchrir/utils.py
CHANGED

@@ -41,6 +41,9 @@ def resolve_device(
     """Resolve a device string (including 'auto') into a torch.device.

     Falls back to CPU when the requested backend is unavailable.
+
+    Example:
+        >>> device = resolve_device("auto")
     """
     if device is None:
         return torch.device("cpu")
@@ -76,7 +79,12 @@ def resolve_device(

 @dataclass(frozen=True)
 class DeviceSpec:
-    """Resolve device + dtype defaults consistently."""
+    """Resolve device + dtype defaults consistently.
+
+    Example:
+        >>> spec = DeviceSpec(device="auto", dtype=torch.float32)
+        >>> device, dtype = spec.resolve(tensor)
+    """

     device: Optional[torch.device | str] = None
     dtype: Optional[torch.dtype] = None
@@ -140,7 +148,11 @@ def estimate_beta_from_t60(
     device: Optional[torch.device | str] = None,
     dtype: Optional[torch.dtype] = None,
 ) -> Tensor:
-    """Estimate reflection coefficients from T60 using Sabine's formula."""
+    """Estimate reflection coefficients from T60 using Sabine's formula.
+
+    Example:
+        >>> beta = estimate_beta_from_t60(torch.tensor([6.0, 4.0, 3.0]), t60=0.4)
+    """
     if t60 <= 0:
         raise ValueError("t60 must be positive")
     size = as_tensor(size, device=device, dtype=dtype)
@@ -172,7 +184,11 @@ def estimate_t60_from_beta(
     device: Optional[torch.device | str] = None,
     dtype: Optional[torch.dtype] = None,
 ) -> float:
-    """Estimate T60 from reflection coefficients using Sabine's formula."""
+    """Estimate T60 from reflection coefficients using Sabine's formula.
+
+    Example:
+        >>> t60 = estimate_t60_from_beta(torch.tensor([6.0, 4.0, 3.0]), beta=torch.full((6,), 0.9))
+    """
     size = as_tensor(size, device=device, dtype=dtype)
     size = ensure_dim(size)
     beta = as_tensor(beta, device=size.device, dtype=size.dtype)
@@ -244,7 +260,11 @@ def orientation_to_unit(orientation: Tensor, dim: int) -> Tensor:


 def att2t_sabine_estimation(att_db: float, t60: float) -> float:
-    """Convert attenuation (dB) to time based on T60."""
+    """Convert attenuation (dB) to time based on T60.
+
+    Example:
+        >>> t = att2t_sabine_estimation(att_db=60.0, t60=0.4)
+    """
     if t60 <= 0:
         raise ValueError("t60 must be positive")
     if att_db <= 0:
@@ -253,17 +273,29 @@ def att2t_sabine_estimation(att_db: float, t60: float) -> float:


 def att2t_SabineEstimation(att_db: float, t60: float) -> float:
-    """Legacy alias for att2t_sabine_estimation."""
+    """Legacy alias for att2t_sabine_estimation.
+
+    Example:
+        >>> t = att2t_SabineEstimation(att_db=60.0, t60=0.4)
+    """
     return att2t_sabine_estimation(att_db, t60)


 def beta_SabineEstimation(room_size: Tensor, t60: float) -> Tensor:
-    """Legacy alias for estimate_beta_from_t60."""
+    """Legacy alias for estimate_beta_from_t60.
+
+    Example:
+        >>> beta = beta_SabineEstimation(torch.tensor([6.0, 4.0, 3.0]), t60=0.4)
+    """
     return estimate_beta_from_t60(room_size, t60)


 def t2n(tmax: float, room_size: Tensor, c: float = _DEF_SPEED_OF_SOUND) -> Tensor:
-    """Estimate image counts per dimension needed to cover tmax."""
+    """Estimate image counts per dimension needed to cover tmax.
+
+    Example:
+        >>> nb_img = t2n(0.3, torch.tensor([6.0, 4.0, 3.0]))
+    """
     if tmax <= 0:
         raise ValueError("tmax must be positive")
     size = as_tensor(room_size)
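The two Sabine helpers are approximate inverses, which makes a quick consistency check possible. A sketch under the signatures shown above (the module path is assumed):

```python
import torch
from torchrir.utils import (  # assumed module path
    estimate_beta_from_t60,
    estimate_t60_from_beta,
    resolve_device,
    t2n,
)

size = torch.tensor([6.0, 4.0, 3.0])
beta = estimate_beta_from_t60(size, t60=0.4)    # six reflection coefficients
t60 = estimate_t60_from_beta(size, beta=beta)   # should land near 0.4 again
nb_img = t2n(0.3, size)                         # images per dimension to cover tmax = 0.3 s
device = resolve_device("auto")                 # cuda/mps if available, else cpu
```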
{torchrir-0.1.0.dist-info → torchrir-0.1.2.dist-info}/METADATA
CHANGED

@@ -1,7 +1,8 @@
 Metadata-Version: 2.4
 Name: torchrir
-Version: 0.1.0
-Summary:
+Version: 0.1.2
+Summary: PyTorch-based room impulse response (RIR) simulation toolkit for static and dynamic scenes.
+Project-URL: Repository, https://github.com/taishi-n/torchrir
 Requires-Python: >=3.10
 Description-Content-Type: text/markdown
 License-File: LICENSE
@@ -18,6 +19,22 @@ This project has been substantially assisted by AI using Codex.
 ## License
 Apache-2.0. See `LICENSE` and `NOTICE`.

+## Installation
+```bash
+pip install torchrir
+```
+
+## Current Capabilities
+- ISM-based static and dynamic RIR simulation (2D/3D shoebox rooms).
+- Directivity patterns: `omni`, `cardioid`, `hypercardioid`, `subcardioid`, `bidir` with orientation handling.
+- Acoustic parameters: `beta` or `t60` (Sabine), optional diffuse tail via `tdiff`.
+- Dynamic convolution via `DynamicConvolver` (`trajectory` or `hop` modes).
+- GPU acceleration for ISM accumulation (CUDA/MPS; MPS disables LUT).
+- Dataset utilities with CMU ARCTIC support and example pipelines.
+- Plotting utilities for static and dynamic scenes.
+- Metadata export helpers for time axis, DOA, and array attributes (JSON-ready).
+- Unified CLI with JSON/YAML config and deterministic flag support.
+
 ## Example Usage
 ```bash
 # CMU ARCTIC + static RIR (fixed sources/mics)
@@ -26,16 +43,22 @@ uv run python examples/static.py --plot
 # Dynamic RIR demos
 uv run python examples/dynamic_mic.py --plot
 uv run python examples/dynamic_src.py --plot
+uv run python examples/dynamic_mic.py --gif
+uv run python examples/dynamic_src.py --gif

 # Unified CLI
 uv run python examples/cli.py --mode static --plot
 uv run python examples/cli.py --mode dynamic_mic --plot
 uv run python examples/cli.py --mode dynamic_src --plot
+uv run python examples/cli.py --mode dynamic_mic --gif
+uv run python examples/dynamic_mic.py --gif --gif-fps 12

 # Config + deterministic
 uv run python examples/cli.py --mode static --deterministic --seed 123 --config-out outputs/cli.json
 uv run python examples/cli.py --config-in outputs/cli.json
 ```
+GIF FPS is auto-derived from signal duration and RIR steps unless overridden with `--gif-fps`.
+For 3D rooms, an additional `*_3d.gif` is saved.
 YAML configs are supported when `PyYAML` is installed.
 ```bash
 # YAML config
@@ -43,6 +66,24 @@ uv run python examples/cli.py --mode static --config-out outputs/cli.yaml
 uv run python examples/cli.py --config-in outputs/cli.yaml
 ```
 `examples/cli_example.yaml` provides a ready-to-use template.
+Examples also save `*_metadata.json` alongside audio outputs.
+
+```python
+from torchrir import DynamicConvolver, MicrophoneArray, Room, Source, simulate_rir
+
+room = Room.shoebox(size=[6.0, 4.0, 3.0], fs=16000, beta=[0.9] * 6)
+sources = Source.positions([[1.0, 2.0, 1.5]])
+mics = MicrophoneArray.positions([[2.0, 2.0, 1.5]])
+
+rir = simulate_rir(
+    room=room,
+    sources=sources,
+    mics=mics,
+    max_order=6,
+    tmax=0.3,
+    device="auto",
+)
+```

 ```python
 from torchrir import DynamicConvolver
@@ -55,6 +96,20 @@ y = DynamicConvolver(mode="hop", hop=1024).convolve(signal, rirs)
 ```
 Dynamic convolution is exposed via `DynamicConvolver` only (no legacy function wrappers).

+## Limitations and Potential Errors
+- Ray tracing and FDTD simulators are placeholders and raise `NotImplementedError`.
+- `TemplateDataset` methods are not implemented and will raise `NotImplementedError`.
+- `simulate_rir`/`simulate_dynamic_rir` require `max_order` (or `SimulationConfig.max_order`) and either `nsample` or `tmax`.
+- Non-`omni` directivity requires orientation; mismatched shapes raise `ValueError`.
+- `beta` must have 4 (2D) or 6 (3D) elements; invalid sizes raise `ValueError`.
+- `simulate_dynamic_rir` requires `src_traj` and `mic_traj` to have matching time steps.
+- Dynamic simulation currently loops per time step; very long trajectories can be slow.
+- MPS disables the sinc LUT path (falls back to direct sinc), which can be slower and slightly different numerically.
+- Deterministic mode is best-effort; some backends may still be non-deterministic.
+- YAML configs require `PyYAML`; otherwise a `ModuleNotFoundError` is raised.
+- CMU ARCTIC downloads require network access.
+- GIF animation output requires Pillow (via matplotlib animation writer).
+
 ### Dataset-agnostic utilities
 ```python
 from torchrir import (
@@ -129,6 +184,7 @@ device, dtype = DeviceSpec(device="auto").resolve()

 ## References
 - [gpuRIR](https://github.com/DavidDiazGuerra/gpuRIR)
+- [Cross3D](https://github.com/DavidDiazGuerra/Cross3D)
 - [pyroomacoustics](https://github.com/LCAV/pyroomacoustics)
 - [das-generator](https://github.com/ehabets/das-generator)
 - [rir-generator](https://github.com/audiolabs/rir-generator)
@@ -211,3 +267,5 @@ y = DynamicConvolver(mode="trajectory").convolve(signal, rirs)
 - FDTD backend: implement `FDTDSimulator` with configurable grid resolution and boundary conditions.
 - Dataset expansion: add additional dataset integrations beyond CMU ARCTIC (see `TemplateDataset`).
 - Enhanced acoustics: frequency-dependent absorption and more advanced diffuse tail models.
+- Add microphone and source directivity models similar to gpuRIR/pyroomacoustics.
+- Add regression tests comparing generated RIRs against gpuRIR outputs.