torchrir 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchrir/__init__.py +85 -0
- torchrir/config.py +59 -0
- torchrir/core.py +741 -0
- torchrir/datasets/__init__.py +27 -0
- torchrir/datasets/base.py +27 -0
- torchrir/datasets/cmu_arctic.py +204 -0
- torchrir/datasets/template.py +65 -0
- torchrir/datasets/utils.py +74 -0
- torchrir/directivity.py +33 -0
- torchrir/dynamic.py +60 -0
- torchrir/logging_utils.py +55 -0
- torchrir/plotting.py +210 -0
- torchrir/plotting_utils.py +173 -0
- torchrir/results.py +22 -0
- torchrir/room.py +150 -0
- torchrir/scene.py +67 -0
- torchrir/scene_utils.py +51 -0
- torchrir/signal.py +233 -0
- torchrir/simulators.py +86 -0
- torchrir/utils.py +281 -0
- torchrir-0.1.0.dist-info/METADATA +213 -0
- torchrir-0.1.0.dist-info/RECORD +26 -0
- torchrir-0.1.0.dist-info/WHEEL +5 -0
- torchrir-0.1.0.dist-info/licenses/LICENSE +190 -0
- torchrir-0.1.0.dist-info/licenses/NOTICE +4 -0
- torchrir-0.1.0.dist-info/top_level.txt +1 -0
torchrir/plotting.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
"""Matplotlib-based plotting helpers for room scenes."""
|
|
4
|
+
|
|
5
|
+
from typing import Any, Iterable, Optional, Sequence, Tuple
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from torch import Tensor
|
|
9
|
+
|
|
10
|
+
from .room import MicrophoneArray, Room, Source
|
|
11
|
+
from .utils import as_tensor, ensure_dim
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def plot_scene_static(
    *,
    room: Room | Sequence[float] | Tensor,
    sources: Source | Tensor | Sequence,
    mics: MicrophoneArray | Tensor | Sequence,
    ax: Any | None = None,
    title: Optional[str] = None,
    show: bool = False,
):
    """Render a static scene: room outline plus source/mic markers.

    Args:
        room: Room instance or raw room-size data.
        sources: Source instance or raw position data.
        mics: MicrophoneArray instance or raw position data.
        ax: Existing matplotlib axes; created when None.
        title: Optional axes title.
        show: Call ``plt.show()`` before returning.

    Returns:
        The matplotlib axes the scene was drawn on.
    """
    plt, ax = _setup_axes(ax, room)

    _draw_room(ax, _room_size(room, ax))

    # Normalize both entity groups first, then draw them in a fixed order.
    markers = (
        (_extract_positions(sources, ax), "sources", "^"),
        (_extract_positions(mics, ax), "mics", "o"),
    )
    for positions, label, marker in markers:
        _scatter_positions(ax, positions, label=label, marker=marker)

    if title:
        ax.set_title(title)
    ax.legend(loc="best")
    if show:
        plt.show()
    return ax
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def plot_scene_dynamic(
    *,
    room: Room | Sequence[float] | Tensor,
    src_traj: Tensor | Sequence,
    mic_traj: Tensor | Sequence,
    step: int = 1,
    ax: Any | None = None,
    title: Optional[str] = None,
    show: bool = False,
):
    """Render source and microphone trajectories inside the room outline.

    Args:
        room: Room instance or raw room-size data.
        src_traj: Source trajectory of shape (T, n_src, dim).
        mic_traj: Microphone trajectory of shape (T, n_mic, dim).
        step: Subsampling stride along the time axis.
        ax: Existing matplotlib axes; created when None.
        title: Optional axes title.
        show: Call ``plt.show()`` before returning.

    Returns:
        The matplotlib axes the scene was drawn on.
    """
    plt, ax = _setup_axes(ax, room)

    _draw_room(ax, _room_size(room, ax))

    # Validate both trajectories up front, then plot them in a fixed order.
    paths = (
        (_as_trajectory(src_traj, ax), "source path"),
        (_as_trajectory(mic_traj, ax), "mic path"),
    )
    for trajectory, label in paths:
        _plot_trajectories(ax, trajectory, step=step, label=label)

    if title:
        ax.set_title(title)
    ax.legend(loc="best")
    if show:
        plt.show()
    return ax
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _setup_axes(ax: Any | None, room: Room | Sequence[float] | Tensor) -> tuple[Any, Any]:
    """Return ``(plt, ax)``, creating 2D or 3D axes when none were given."""
    import matplotlib.pyplot as plt

    dim = _room_size(room, ax).numel()
    if ax is None:
        if dim == 3:
            # 3D rooms need an Axes3D subplot.
            ax = plt.figure().add_subplot(111, projection="3d")
        else:
            _, ax = plt.subplots()
    return plt, ax
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def _room_size(room: Room | Sequence[float] | Tensor, ax: Any | None) -> Tensor:
    """Normalize a Room or raw size input to a 1D tensor."""
    raw = room.size if isinstance(room, Room) else room
    return ensure_dim(as_tensor(raw))
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def _draw_room(ax: Any, size: Tensor) -> None:
    """Render the room outline, dispatching on its dimensionality."""
    renderer = _draw_room_2d if size.numel() == 2 else _draw_room_3d
    renderer(ax, size)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def _draw_room_2d(ax: Any, size: Tensor) -> None:
    """Draw a rectangular floor plan anchored at the origin."""
    import matplotlib.patches as patches

    width = size[0].item()
    depth = size[1].item()
    outline = patches.Rectangle((0.0, 0.0), width, depth, fill=False, edgecolor="black")
    ax.add_patch(outline)
    ax.set_xlim(0, width)
    ax.set_ylim(0, depth)
    # Equal aspect so the room is not visually distorted.
    ax.set_aspect("equal", adjustable="box")
    ax.set_xlabel("x")
    ax.set_ylabel("y")
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def _draw_room_3d(ax: Any, size: Tensor) -> None:
|
|
123
|
+
"""Draw a 3D box representing the room."""
|
|
124
|
+
x, y, z = size.tolist()
|
|
125
|
+
corners = torch.tensor(
|
|
126
|
+
[
|
|
127
|
+
[0, 0, 0],
|
|
128
|
+
[x, 0, 0],
|
|
129
|
+
[x, y, 0],
|
|
130
|
+
[0, y, 0],
|
|
131
|
+
[0, 0, z],
|
|
132
|
+
[x, 0, z],
|
|
133
|
+
[x, y, z],
|
|
134
|
+
[0, y, z],
|
|
135
|
+
],
|
|
136
|
+
dtype=torch.float32,
|
|
137
|
+
)
|
|
138
|
+
edges = [
|
|
139
|
+
(0, 1),
|
|
140
|
+
(1, 2),
|
|
141
|
+
(2, 3),
|
|
142
|
+
(3, 0),
|
|
143
|
+
(4, 5),
|
|
144
|
+
(5, 6),
|
|
145
|
+
(6, 7),
|
|
146
|
+
(7, 4),
|
|
147
|
+
(0, 4),
|
|
148
|
+
(1, 5),
|
|
149
|
+
(2, 6),
|
|
150
|
+
(3, 7),
|
|
151
|
+
]
|
|
152
|
+
for a, b in edges:
|
|
153
|
+
ax.plot(
|
|
154
|
+
[corners[a, 0], corners[b, 0]],
|
|
155
|
+
[corners[a, 1], corners[b, 1]],
|
|
156
|
+
[corners[a, 2], corners[b, 2]],
|
|
157
|
+
color="black",
|
|
158
|
+
)
|
|
159
|
+
ax.set_xlim(0, x)
|
|
160
|
+
ax.set_ylim(0, y)
|
|
161
|
+
ax.set_zlim(0, z)
|
|
162
|
+
ax.set_xlabel("x")
|
|
163
|
+
ax.set_ylabel("y")
|
|
164
|
+
ax.set_zlabel("z")
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def _extract_positions(entity: Source | MicrophoneArray | Tensor | Sequence, ax: Any | None) -> Tensor:
    """Normalize a Source/MicrophoneArray or raw data to an (N, dim) tensor."""
    raw = entity.positions if isinstance(entity, (Source, MicrophoneArray)) else entity
    pos = as_tensor(raw)
    # A single flat position becomes a one-row batch.
    return pos.unsqueeze(0) if pos.ndim == 1 else pos
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def _scatter_positions(ax: Any, positions: Tensor, *, label: str, marker: str) -> None:
|
|
180
|
+
"""Scatter-plot positions in 2D or 3D."""
|
|
181
|
+
if positions.numel() == 0:
|
|
182
|
+
return
|
|
183
|
+
dim = positions.shape[1]
|
|
184
|
+
if dim == 2:
|
|
185
|
+
ax.scatter(positions[:, 0], positions[:, 1], label=label, marker=marker)
|
|
186
|
+
else:
|
|
187
|
+
ax.scatter(positions[:, 0], positions[:, 1], positions[:, 2], label=label, marker=marker)
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def _as_trajectory(traj: Tensor | Sequence, ax: Any | None) -> Tensor:
    """Coerce input to a tensor and require the (T, N, dim) layout."""
    out = as_tensor(traj)
    if out.ndim != 3:
        raise ValueError("trajectory must be of shape (T, N, dim)")
    return out
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def _plot_trajectories(ax: Any, traj: Tensor, *, step: int, label: str) -> None:
|
|
199
|
+
"""Plot trajectories for each entity."""
|
|
200
|
+
if traj.numel() == 0:
|
|
201
|
+
return
|
|
202
|
+
dim = traj.shape[2]
|
|
203
|
+
if dim == 2:
|
|
204
|
+
for idx in range(traj.shape[1]):
|
|
205
|
+
xy = traj[::step, idx]
|
|
206
|
+
ax.plot(xy[:, 0], xy[:, 1], label=f"{label} {idx}")
|
|
207
|
+
else:
|
|
208
|
+
for idx in range(traj.shape[1]):
|
|
209
|
+
xyz = traj[::step, idx]
|
|
210
|
+
ax.plot(xyz[:, 0], xyz[:, 1], xyz[:, 2], label=f"{label} {idx}")
|
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
"""Higher-level plotting utilities used by examples."""
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any, Optional, Sequence
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
|
|
10
|
+
from .plotting import plot_scene_dynamic, plot_scene_static
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def plot_scene_and_save(
    *,
    out_dir: Path,
    room: Sequence[float] | torch.Tensor,
    sources: object | torch.Tensor | Sequence,
    mics: object | torch.Tensor | Sequence,
    src_traj: Optional[torch.Tensor | Sequence] = None,
    mic_traj: Optional[torch.Tensor | Sequence] = None,
    prefix: str = "scene",
    step: int = 1,
    show: bool = False,
    plot_2d: bool = True,
    plot_3d: bool = True,
) -> tuple[list[Path], list[Path]]:
    """Plot static and dynamic scenes and save images to disk.

    Args:
        out_dir: Output directory for PNGs.
        room: Room size tensor or sequence.
        sources: Source positions or Source-like object.
        mics: Microphone positions or MicrophoneArray-like object.
        src_traj: Optional source trajectory (T, n_src, dim).
        mic_traj: Optional mic trajectory (T, n_mic, dim).
        prefix: Filename prefix for saved images.
        step: Subsampling step for trajectories.
        show: Whether to show figures interactively.
        plot_2d: Save 2D projections.
        plot_3d: Save 3D projections (only if dim == 3).

    Returns:
        Tuple of (static_paths, dynamic_paths).
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Normalize all geometry to CPU tensors so the slicing below is uniform.
    room_size = _to_cpu(room)
    src_pos = _positions_to_cpu(sources)
    mic_pos = _positions_to_cpu(mics)
    dim = int(room_size.numel())

    static_paths: list[Path] = []
    dynamic_paths: list[Path] = []

    # One pass per requested view; a view is skipped when the scene has fewer
    # spatial dimensions than the view needs.
    for view_dim, enabled in ((2, plot_2d), (3, plot_3d)):
        if not enabled:
            continue
        if view_dim == 2 and dim < 2:
            continue
        if view_dim == 3 and dim < 3:
            continue
        # Project room extents and positions onto the first view_dim axes.
        view_room = room_size[:view_dim]
        view_src = src_pos[:, :view_dim]
        view_mic = mic_pos[:, :view_dim]

        ax = plot_scene_static(
            room=view_room,
            sources=view_src,
            mics=view_mic,
            title=f"Room scene ({view_dim}D static)",
            show=False,
        )
        static_path = out_dir / f"{prefix}_static_{view_dim}d.png"
        _save_axes(ax, static_path, show=show)
        static_paths.append(static_path)

        if src_traj is not None or mic_traj is not None:
            steps = _traj_steps(src_traj, mic_traj)
            # NOTE: src_traj/mic_traj are deliberately rebound to their
            # normalized (T, N, dim) CPU tensors here, so later view
            # iterations reuse the converted tensors; the conversion is
            # idempotent for tensor input (see _trajectory_to_cpu/_to_cpu).
            src_traj = _trajectory_to_cpu(src_traj, src_pos, steps)
            mic_traj = _trajectory_to_cpu(mic_traj, mic_pos, steps)
            view_src_traj = src_traj[:, :, :view_dim]
            view_mic_traj = mic_traj[:, :, :view_dim]
            ax = plot_scene_dynamic(
                room=view_room,
                src_traj=view_src_traj,
                mic_traj=view_mic_traj,
                step=step,
                title=f"Room scene ({view_dim}D trajectories)",
                show=False,
            )
            # Static markers drawn on top of the trajectory lines.
            _overlay_positions(ax, view_src, view_mic)
            dynamic_path = out_dir / f"{prefix}_dynamic_{view_dim}d.png"
            _save_axes(ax, dynamic_path, show=show)
            dynamic_paths.append(dynamic_path)

    return static_paths, dynamic_paths
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def _to_cpu(value: Any) -> torch.Tensor:
|
|
101
|
+
"""Move a value to CPU as a tensor."""
|
|
102
|
+
if torch.is_tensor(value):
|
|
103
|
+
return value.detach().cpu()
|
|
104
|
+
return torch.as_tensor(value).detach().cpu()
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def _positions_to_cpu(entity: torch.Tensor | object) -> torch.Tensor:
    """Resolve an entity's ``.positions`` (or the entity itself) to an (N, dim) CPU tensor."""
    pos = _to_cpu(getattr(entity, "positions", entity))
    # A single flat position becomes a one-row batch.
    return pos.unsqueeze(0) if pos.ndim == 1 else pos
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def _traj_steps(src_traj: Optional[torch.Tensor | Sequence], mic_traj: Optional[torch.Tensor | Sequence]) -> int:
    """Infer the number of trajectory time steps from whichever input is given.

    Args:
        src_traj: Optional source trajectory; its first axis is the time axis.
        mic_traj: Optional mic trajectory, used when src_traj is None.

    Returns:
        The number of time steps T.

    Raises:
        ValueError: If both trajectories are None. (Previously this crashed
        with an opaque error from ``torch.as_tensor(None)``.)
    """
    if src_traj is not None:
        return int(_to_cpu(src_traj).shape[0])
    if mic_traj is not None:
        return int(_to_cpu(mic_traj).shape[0])
    raise ValueError("at least one of src_traj or mic_traj must be provided")
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def _trajectory_to_cpu(
|
|
124
|
+
traj: Optional[torch.Tensor | Sequence], fallback_pos: torch.Tensor, steps: int
|
|
125
|
+
) -> torch.Tensor:
|
|
126
|
+
"""Normalize trajectory to CPU tensor with shape (T, N, dim)."""
|
|
127
|
+
if traj is None:
|
|
128
|
+
return fallback_pos.unsqueeze(0).repeat(steps, 1, 1)
|
|
129
|
+
traj = _to_cpu(traj)
|
|
130
|
+
if traj.ndim != 3:
|
|
131
|
+
raise ValueError("trajectory must be of shape (T, N, dim)")
|
|
132
|
+
return traj
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def _save_axes(ax: Any, path: Path, *, show: bool) -> None:
    """Tighten layout, write the axes' figure to *path*, optionally show, then close."""
    import matplotlib.pyplot as plt

    figure = ax.figure
    figure.tight_layout()
    figure.savefig(path, dpi=150)
    if show:
        plt.show()
    # Always close so repeated saves don't accumulate open figures.
    plt.close(figure)
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def _overlay_positions(ax: Any, sources: torch.Tensor, mics: torch.Tensor) -> None:
|
|
148
|
+
"""Overlay static source and mic positions on an axis."""
|
|
149
|
+
if sources.numel() > 0:
|
|
150
|
+
if sources.shape[1] == 2:
|
|
151
|
+
ax.scatter(sources[:, 0], sources[:, 1], marker="^", label="sources", color="tab:green")
|
|
152
|
+
else:
|
|
153
|
+
ax.scatter(
|
|
154
|
+
sources[:, 0],
|
|
155
|
+
sources[:, 1],
|
|
156
|
+
sources[:, 2],
|
|
157
|
+
marker="^",
|
|
158
|
+
label="sources",
|
|
159
|
+
color="tab:green",
|
|
160
|
+
)
|
|
161
|
+
if mics.numel() > 0:
|
|
162
|
+
if mics.shape[1] == 2:
|
|
163
|
+
ax.scatter(mics[:, 0], mics[:, 1], marker="o", label="mics", color="tab:orange")
|
|
164
|
+
else:
|
|
165
|
+
ax.scatter(
|
|
166
|
+
mics[:, 0],
|
|
167
|
+
mics[:, 1],
|
|
168
|
+
mics[:, 2],
|
|
169
|
+
marker="o",
|
|
170
|
+
label="mics",
|
|
171
|
+
color="tab:orange",
|
|
172
|
+
)
|
|
173
|
+
ax.legend(loc="best")
|
torchrir/results.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
"""Result containers for simulation outputs."""
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import Optional
|
|
7
|
+
|
|
8
|
+
from torch import Tensor
|
|
9
|
+
|
|
10
|
+
from .config import SimulationConfig
|
|
11
|
+
from .scene import Scene
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass(frozen=True)
class RIRResult:
    """Container for RIRs with metadata.

    Immutable record bundling the simulated room impulse responses with the
    scene and configuration that produced them.
    """

    # Simulated room impulse responses; exact layout is set by the producing
    # simulator and is not visible here — presumably (n_mic, n_src, taps) or
    # similar. TODO(review): confirm against torchrir.simulators.
    rirs: Tensor
    # The Scene (room, sources, mics, optional trajectories) that was simulated.
    scene: Scene
    # Simulation parameters the RIRs were generated with.
    config: SimulationConfig
    # Optional per-snapshot timestamps for dynamic scenes — units not shown
    # here; verify (samples vs. seconds) against the simulator.
    timestamps: Optional[Tensor] = None
    # RNG seed, recorded when the simulation was seeded.
    seed: Optional[int] = None
|
torchrir/room.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
"""Room, source, and microphone geometry models."""
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, replace
|
|
6
|
+
from typing import Optional, Sequence
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
from torch import Tensor
|
|
10
|
+
|
|
11
|
+
from .utils import as_tensor, ensure_dim
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass(frozen=True)
class Room:
    """Room geometry and acoustic parameters.

    Immutable shoebox description: spatial extents, sample rate, speed of
    sound, and either per-wall reflection coefficients (``beta``) or a
    reverberation time (``t60``) — never both.
    """

    size: Tensor
    fs: float
    c: float = 343.0
    beta: Optional[Tensor] = None
    t60: Optional[float] = None

    def __post_init__(self) -> None:
        """Normalize the size field and reject conflicting reflection specs."""
        # Frozen dataclass: mutate via object.__setattr__.
        object.__setattr__(self, "size", ensure_dim(self.size))
        if self.beta is not None and self.t60 is not None:
            raise ValueError("beta and t60 are mutually exclusive")

    def replace(self, **kwargs) -> "Room":
        """Return a copy of this Room with the given fields overridden."""
        return replace(self, **kwargs)

    @staticmethod
    def shoebox(
        size: Sequence[float] | Tensor,
        *,
        fs: float,
        c: float = 343.0,
        beta: Optional[Sequence[float] | Tensor] = None,
        t60: Optional[float] = None,
        device: Optional[torch.device | str] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> "Room":
        """Build a rectangular (shoebox) room from plain size values."""
        dims = ensure_dim(as_tensor(size, device=device, dtype=dtype))
        walls = as_tensor(beta, device=device, dtype=dtype) if beta is not None else None
        return Room(size=dims, fs=fs, c=c, beta=walls, t60=t60)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
@dataclass(frozen=True)
class Source:
    """Source container with positions and optional orientation."""

    # Position data; coerced to a tensor in __post_init__. Presumably shaped
    # (n_src, dim) — callers elsewhere index shape[0] as the source count.
    positions: Tensor
    # Optional orientation tensor; its semantics (angles vs. direction
    # vectors) are not visible here — verify against the directivity code.
    orientation: Optional[Tensor] = None

    def __post_init__(self) -> None:
        # Frozen dataclass: normalize fields via object.__setattr__.
        pos = as_tensor(self.positions)
        object.__setattr__(self, "positions", pos)
        if self.orientation is not None:
            ori = as_tensor(self.orientation)
            object.__setattr__(self, "orientation", ori)

    def replace(self, **kwargs) -> "Source":
        """Return a new Source with updated fields."""
        return replace(self, **kwargs)

    @classmethod
    def positions(
        cls,
        positions: Sequence[Sequence[float]] | Tensor,
        *,
        orientation: Optional[Sequence[float] | Tensor] = None,
        device: Optional[torch.device | str] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> "Source":
        """Construct a Source from positions.

        NOTE(review): this classmethod shares its name with the ``positions``
        field above, so the class attribute defined here appears to act as the
        field's dataclass default while instance access still returns the
        tensor set in ``__post_init__``. This seems to work but is fragile —
        consider dropping this alias in favor of ``from_positions`` in a
        future (interface-breaking) release.
        """
        return cls.from_positions(
            positions, orientation=orientation, device=device, dtype=dtype
        )

    @classmethod
    def from_positions(
        cls,
        positions: Sequence[Sequence[float]] | Tensor,
        *,
        orientation: Optional[Sequence[float] | Tensor] = None,
        device: Optional[torch.device | str] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> "Source":
        """Convert positions/orientation to tensors and build a Source."""
        pos = as_tensor(positions, device=device, dtype=dtype)
        ori = None
        if orientation is not None:
            ori = as_tensor(orientation, device=device, dtype=dtype)
        return cls(pos, ori)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
@dataclass(frozen=True)
class MicrophoneArray:
    """Microphone array container."""

    # Position data; coerced to a tensor in __post_init__. Presumably shaped
    # (n_mic, dim) — callers elsewhere index shape[0] as the mic count.
    positions: Tensor
    # Optional orientation tensor; semantics not visible here — verify
    # against the directivity code before relying on a convention.
    orientation: Optional[Tensor] = None

    def __post_init__(self) -> None:
        # Frozen dataclass: normalize fields via object.__setattr__.
        pos = as_tensor(self.positions)
        object.__setattr__(self, "positions", pos)
        if self.orientation is not None:
            ori = as_tensor(self.orientation)
            object.__setattr__(self, "orientation", ori)

    def replace(self, **kwargs) -> "MicrophoneArray":
        """Return a new MicrophoneArray with updated fields."""
        return replace(self, **kwargs)

    @classmethod
    def positions(
        cls,
        positions: Sequence[Sequence[float]] | Tensor,
        *,
        orientation: Optional[Sequence[float] | Tensor] = None,
        device: Optional[torch.device | str] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> "MicrophoneArray":
        """Construct a MicrophoneArray from positions.

        NOTE(review): this classmethod shares its name with the ``positions``
        field above (same pattern as Source.positions). It appears to work,
        but the name collision is fragile — prefer ``from_positions``.
        """
        return cls.from_positions(
            positions, orientation=orientation, device=device, dtype=dtype
        )

    @classmethod
    def from_positions(
        cls,
        positions: Sequence[Sequence[float]] | Tensor,
        *,
        orientation: Optional[Sequence[float] | Tensor] = None,
        device: Optional[torch.device | str] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> "MicrophoneArray":
        """Convert positions/orientation to tensors and build a MicrophoneArray."""
        pos = as_tensor(positions, device=device, dtype=dtype)
        ori = None
        if orientation is not None:
            ori = as_tensor(orientation, device=device, dtype=dtype)
        return cls(pos, ori)
|
torchrir/scene.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
"""Scene container for simulation inputs."""
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import Optional
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
from torch import Tensor
|
|
10
|
+
|
|
11
|
+
from .room import MicrophoneArray, Room, Source
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass(frozen=True)
class Scene:
    """Container for room, sources, microphones, and optional trajectories."""

    room: Room
    sources: Source
    mics: MicrophoneArray
    src_traj: Optional[Tensor] = None
    mic_traj: Optional[Tensor] = None

    def is_dynamic(self) -> bool:
        """Return True if any trajectory is provided."""
        return not (self.src_traj is None and self.mic_traj is None)

    def validate(self) -> None:
        """Validate scene consistency and trajectory shapes.

        Raises:
            TypeError: if room/sources/mics have the wrong types.
            ValueError: if a trajectory has the wrong shape or the two
                trajectories disagree on the number of time steps.
        """
        type_checks = (
            (self.room, Room, "room must be a Room instance"),
            (self.sources, Source, "sources must be a Source instance"),
            (self.mics, MicrophoneArray, "mics must be a MicrophoneArray instance"),
        )
        for value, expected, message in type_checks:
            if not isinstance(value, expected):
                raise TypeError(message)

        dim = self.room.size.numel()
        t_src = _validate_traj(self.src_traj, self.sources.positions.shape[0], dim, "src_traj")
        t_mic = _validate_traj(self.mic_traj, self.mics.positions.shape[0], dim, "mic_traj")
        if t_src is not None and t_mic is not None and t_src != t_mic:
            raise ValueError("src_traj and mic_traj must have matching time steps")
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _validate_traj(
|
|
48
|
+
traj: Optional[Tensor],
|
|
49
|
+
count: int,
|
|
50
|
+
dim: int,
|
|
51
|
+
name: str,
|
|
52
|
+
) -> Optional[int]:
|
|
53
|
+
if traj is None:
|
|
54
|
+
return None
|
|
55
|
+
if not torch.is_tensor(traj):
|
|
56
|
+
raise TypeError(f"{name} must be a Tensor")
|
|
57
|
+
if traj.ndim == 2:
|
|
58
|
+
if count != 1:
|
|
59
|
+
raise ValueError(f"{name} must have shape (T, {count}, {dim})")
|
|
60
|
+
if traj.shape[1] != dim:
|
|
61
|
+
raise ValueError(f"{name} must have shape (T, {dim}) for single entity")
|
|
62
|
+
return traj.shape[0]
|
|
63
|
+
if traj.ndim == 3:
|
|
64
|
+
if traj.shape[1] != count or traj.shape[2] != dim:
|
|
65
|
+
raise ValueError(f"{name} must have shape (T, {count}, {dim})")
|
|
66
|
+
return traj.shape[0]
|
|
67
|
+
raise ValueError(f"{name} must have shape (T, {count}, {dim})")
|
torchrir/scene_utils.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
"""Scene-agnostic utilities for example setups."""
|
|
4
|
+
|
|
5
|
+
import random
|
|
6
|
+
from typing import List
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def sample_positions(
    *,
    num: int,
    room_size: torch.Tensor,
    rng: random.Random,
    margin: float = 0.5,
) -> torch.Tensor:
    """Draw ``num`` uniform random points inside the room, ``margin`` away from walls."""
    # Per-axis (low, high) sampling bounds. Draw order (per point, per axis)
    # is kept stable so seeded RNGs reproduce the same layout.
    bounds = [(margin, float(extent) - margin) for extent in room_size.tolist()]
    points = [[rng.uniform(lo, hi) for lo, hi in bounds] for _ in range(num)]
    return torch.tensor(points, dtype=torch.float32)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def linear_trajectory(start: torch.Tensor, end: torch.Tensor, steps: int) -> torch.Tensor:
    """Create a linear trajectory between start and end.

    Args:
        start: Starting position tensor.
        end: Ending position tensor (same shape as ``start``).
        steps: Number of points, including both endpoints.

    Returns:
        Tensor of shape ``(steps, *start.shape)`` interpolating from ``start``
        to ``end``. For ``steps == 1`` the single point is ``start``
        (previously this raised ZeroDivisionError).

    Raises:
        ValueError: If ``steps`` is less than 1.
    """
    if steps < 1:
        raise ValueError("steps must be at least 1")
    # For steps == 1 the denominator would be zero; clamp it so the single
    # interpolation weight is 0.0 (i.e. the start point).
    denom = max(steps - 1, 1)
    return torch.stack(
        [start + (end - start) * (t / denom) for t in range(steps)],
        dim=0,
    )
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def binaural_mic_positions(center: torch.Tensor, offset: float = 0.08) -> torch.Tensor:
    """Create a two-mic binaural layout around a center point.

    The pair is split along the first axis: left at ``center - offset``,
    right at ``center + offset``.
    """
    shift = torch.zeros((center.numel(),), dtype=torch.float32)
    shift[0] = offset
    return torch.stack([center - shift, center + shift], dim=0)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def clamp_positions(positions: torch.Tensor, room_size: torch.Tensor, margin: float = 0.1) -> torch.Tensor:
    """Clamp positions to remain inside the room with a margin.

    Clips to the upper bound first, then the lower bound, so the lower bound
    wins when the margins overlap in a degenerate room.
    """
    upper = room_size - margin
    lower = torch.full_like(room_size, margin)
    capped = torch.min(positions, upper)
    return torch.max(capped, lower)
|