torchrir 0.1.0__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchrir/__init__.py +11 -1
- torchrir/animation.py +175 -0
- torchrir/config.py +11 -2
- torchrir/core.py +63 -16
- torchrir/datasets/cmu_arctic.py +21 -3
- torchrir/datasets/template.py +3 -1
- torchrir/datasets/utils.py +25 -3
- torchrir/dynamic.py +14 -3
- torchrir/logging_utils.py +17 -3
- torchrir/metadata.py +216 -0
- torchrir/plotting.py +124 -24
- torchrir/plotting_utils.py +19 -31
- torchrir/results.py +7 -1
- torchrir/room.py +20 -32
- torchrir/scene.py +6 -1
- torchrir/scene_utils.py +28 -6
- torchrir/signal.py +30 -10
- torchrir/simulators.py +17 -5
- torchrir/utils.py +40 -8
- torchrir-0.1.4.dist-info/METADATA +70 -0
- torchrir-0.1.4.dist-info/RECORD +28 -0
- torchrir-0.1.0.dist-info/METADATA +0 -213
- torchrir-0.1.0.dist-info/RECORD +0 -26
- {torchrir-0.1.0.dist-info → torchrir-0.1.4.dist-info}/WHEEL +0 -0
- {torchrir-0.1.0.dist-info → torchrir-0.1.4.dist-info}/licenses/LICENSE +0 -0
- {torchrir-0.1.0.dist-info → torchrir-0.1.4.dist-info}/licenses/NOTICE +0 -0
- {torchrir-0.1.0.dist-info → torchrir-0.1.4.dist-info}/top_level.txt +0 -0
torchrir/metadata.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
"""Metadata helpers for simulation outputs."""
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any, Dict, Optional
|
|
8
|
+
|
|
9
|
+
import json
|
|
10
|
+
import torch
|
|
11
|
+
from torch import Tensor
|
|
12
|
+
|
|
13
|
+
from .room import MicrophoneArray, Room, Source
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass(frozen=True)
class ArrayAttributes:
    """Structured description of a microphone array.

    Immutable record of array-level attributes. The class itself performs no
    validation or conversion; values are stored exactly as supplied.
    """

    # Short label for the array layout; `_array_attributes` in this module
    # emits "single", "binaural" or "custom" depending on the microphone count.
    geometry_name: str
    # Microphone coordinates (presumably shaped (n_mic, dim) -- confirm
    # against MicrophoneArray).
    positions: Tensor
    # Array orientation, or None when no orientation is available.
    orientation: Optional[Tensor]
    # Reference centre of the array; `_array_attributes` uses the mean of the
    # microphone positions.
    center: Tensor
    # Array normal vector, or None when not available.
    normal: Optional[Tensor]
    # Representative inter-microphone distance, or None when not available;
    # `_array_attributes` uses the smallest non-zero pairwise distance.
    spacing: Optional[float]
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def build_metadata(
    *,
    room: Room,
    sources: Source,
    mics: MicrophoneArray,
    rirs: Tensor,
    src_traj: Optional[Tensor] = None,
    mic_traj: Optional[Tensor] = None,
    timestamps: Optional[Tensor] = None,
    signal_len: Optional[int] = None,
    source_info: Optional[Any] = None,
    extra: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Build JSON-serializable metadata for a simulation output.

    Args:
        room: Room whose ``size``, ``c``, ``beta``, ``t60`` and ``fs`` are
            recorded. The number of elements in ``room.size`` is used as the
            spatial dimension.
        sources: Source container; ``positions`` and ``orientation`` are
            recorded.
        mics: Microphone array; ``positions`` and ``orientation`` are recorded
            and array-level attributes are derived from it.
        rirs: Impulse responses. The last axis is taken as the sample axis;
            the full shape is stored under ``"rirs_shape"``.
        src_traj: Optional source trajectory, ``(T, dim)`` or ``(T, N, dim)``.
            When omitted, the static source positions are used as one step.
        mic_traj: Optional microphone trajectory, same conventions as
            ``src_traj``.
        timestamps: Optional per-step timestamps. Stored as given; no check is
            made against the number of trajectory steps.
        signal_len: Optional signal length in samples. Only used when
            ``timestamps`` is omitted and the scene has more than one step: the
            steps are then spread evenly over ``[0, (signal_len - 1) / fs]``.
        source_info: Optional extra description of the sources, stored under
            ``"source_info"`` when not ``None``.
        extra: Optional additional entries, stored under ``"extra"`` when
            non-empty.

    Returns:
        A dictionary with the keys ``room``, ``sources``, ``mics``,
        ``trajectories``, ``array``, ``time_axis``, ``doa``, ``timestamps``,
        ``rirs_shape`` and ``dynamic`` (plus ``source_info`` / ``extra`` when
        supplied). Tensors are converted to nested lists.

    Raises:
        ValueError: If the two trajectories still differ in their number of
            time steps after a single-step trajectory has been broadcast, or
            if trajectory/position shapes are rejected by ``_normalize_traj``.
        TypeError: If a supplied trajectory is not a tensor.

    Example:
        >>> metadata = build_metadata(
        ...     room=room,
        ...     sources=sources,
        ...     mics=mics,
        ...     rirs=rirs,
        ...     src_traj=src_traj,
        ...     mic_traj=mic_traj,
        ...     signal_len=signal.shape[-1],
        ... )
        >>> save_metadata_json(Path("outputs/scene_metadata.json"), metadata)
    """
    nsample = int(rirs.shape[-1])
    fs = float(room.fs)
    time_axis = {
        "fs": fs,
        "nsample": nsample,
        # float64 keeps the serialized sample times free of float32 rounding
        # noise and exact for sample indices beyond 2**24.
        "t": _to_serializable(torch.arange(nsample, dtype=torch.float64) / fs),
    }

    src_pos = sources.positions
    mic_pos = mics.positions
    dim = int(room.size.numel())
    src_traj_n = _normalize_traj(src_traj, src_pos, dim, "src_traj")
    mic_traj_n = _normalize_traj(mic_traj, mic_pos, dim, "mic_traj")

    # Broadcast a single-step (static) side to the length of the moving side
    # so both trajectories share one time axis.
    t_steps = max(src_traj_n.shape[0], mic_traj_n.shape[0])
    if src_traj_n.shape[0] == 1 and t_steps > 1:
        src_traj_n = src_traj_n.expand(t_steps, -1, -1)
    if mic_traj_n.shape[0] == 1 and t_steps > 1:
        mic_traj_n = mic_traj_n.expand(t_steps, -1, -1)
    if src_traj_n.shape[0] != mic_traj_n.shape[0]:
        raise ValueError("src_traj and mic_traj must have matching time steps")

    azimuth, elevation = _compute_doa(src_traj_n, mic_traj_n)
    doa = {
        "frame": "world",
        "unit": "radians",
        "azimuth": _to_serializable(azimuth),
        "elevation": _to_serializable(elevation),
    }

    timestamps_out: Optional[Tensor] = None
    if timestamps is not None:
        timestamps_out = timestamps
    elif t_steps > 1 and signal_len is not None:
        # Spread the steps evenly from the first to the last sample instant.
        duration = max(0.0, (float(signal_len) - 1.0) / fs)
        timestamps_out = torch.linspace(0.0, duration, t_steps, dtype=torch.float64)

    array_attrs = _array_attributes(mics)
    is_dynamic = t_steps > 1

    metadata: Dict[str, Any] = {
        "room": {
            "size": _to_serializable(room.size),
            "c": float(room.c),
            "beta": _to_serializable(room.beta) if room.beta is not None else None,
            "t60": float(room.t60) if room.t60 is not None else None,
            "fs": fs,
        },
        "sources": {
            "positions": _to_serializable(src_pos),
            "orientation": _to_serializable(sources.orientation),
        },
        "mics": {
            "positions": _to_serializable(mic_pos),
            "orientation": _to_serializable(mics.orientation),
        },
        # Trajectories are only recorded for dynamic scenes; static scenes are
        # fully described by the position entries above.
        "trajectories": {
            "sources": _to_serializable(src_traj_n if is_dynamic else None),
            "mics": _to_serializable(mic_traj_n if is_dynamic else None),
        },
        "array": {
            "geometry": array_attrs.geometry_name,
            "positions": _to_serializable(array_attrs.positions),
            "orientation": _to_serializable(array_attrs.orientation),
            "center": _to_serializable(array_attrs.center),
            "normal": _to_serializable(array_attrs.normal),
            "spacing": array_attrs.spacing,
        },
        "time_axis": time_axis,
        "doa": doa,
        "timestamps": _to_serializable(timestamps_out),
        "rirs_shape": list(rirs.shape),
        "dynamic": bool(is_dynamic),
    }

    if source_info is not None:
        metadata["source_info"] = _to_serializable(source_info)
    if extra:
        metadata["extra"] = _to_serializable(extra)
    return metadata
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def save_metadata_json(path: Path, metadata: Dict[str, Any]) -> None:
    """Save metadata as indented UTF-8 JSON at the given path.

    Missing parent directories are created first. An existing file at
    ``path`` is overwritten.

    Args:
        path: Destination file. A ``Path`` is expected, but any value accepted
            by the ``Path`` constructor (for example a plain string) works.
        metadata: JSON-serializable mapping to write.

    Raises:
        TypeError: Propagated from ``json.dump`` when ``metadata`` contains an
            object the JSON encoder cannot handle.
        OSError: Propagated when the directory or file cannot be created.

    Example:
        >>> save_metadata_json(Path("outputs/scene_metadata.json"), metadata)
    """
    # Coerce so that string paths do not fail on `.parent` / `.open`.
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    with target.open("w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def _normalize_traj(traj: Optional[Tensor], pos: Tensor, dim: int, name: str) -> Tensor:
|
|
148
|
+
if traj is None:
|
|
149
|
+
if pos.ndim != 2 or pos.shape[1] != dim:
|
|
150
|
+
raise ValueError(f"{name} default positions must have shape (N, {dim})")
|
|
151
|
+
return pos.unsqueeze(0)
|
|
152
|
+
if not torch.is_tensor(traj):
|
|
153
|
+
raise TypeError(f"{name} must be a Tensor")
|
|
154
|
+
if traj.ndim == 2:
|
|
155
|
+
if traj.shape[1] != dim:
|
|
156
|
+
raise ValueError(f"{name} must have shape (T, {dim})")
|
|
157
|
+
return traj.unsqueeze(1)
|
|
158
|
+
if traj.ndim == 3:
|
|
159
|
+
if traj.shape[2] != dim:
|
|
160
|
+
raise ValueError(f"{name} must have shape (T, N, {dim})")
|
|
161
|
+
return traj
|
|
162
|
+
raise ValueError(f"{name} must have shape (T, N, {dim})")
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def _compute_doa(src_traj: Tensor, mic_traj: Tensor) -> tuple[Tensor, Tensor]:
|
|
166
|
+
vec = src_traj[:, :, None, :] - mic_traj[:, None, :, :]
|
|
167
|
+
x = vec[..., 0]
|
|
168
|
+
y = vec[..., 1]
|
|
169
|
+
azimuth = torch.atan2(y, x)
|
|
170
|
+
if vec.shape[-1] < 3:
|
|
171
|
+
elevation = torch.zeros_like(azimuth)
|
|
172
|
+
else:
|
|
173
|
+
z = vec[..., 2]
|
|
174
|
+
r_xy = torch.sqrt(x**2 + y**2)
|
|
175
|
+
elevation = torch.atan2(z, r_xy)
|
|
176
|
+
return azimuth, elevation
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def _array_attributes(mics: MicrophoneArray) -> ArrayAttributes:
    """Derive array-level attributes from a microphone array.

    Args:
        mics: Microphone array; its ``positions`` (first axis indexes the
            microphones) and ``orientation`` are read.

    Returns:
        An ``ArrayAttributes`` whose geometry label is ``"single"`` for one
        microphone, ``"binaural"`` for two and ``"custom"`` otherwise; whose
        center is the mean microphone position; and whose spacing is the
        smallest non-zero pairwise distance (``None`` for fewer than two
        microphones or when all microphones coincide). ``normal`` is always
        ``None``.
    """
    pos = mics.positions
    n_mic = pos.shape[0]
    if n_mic == 1:
        geometry = "single"
    elif n_mic == 2:
        geometry = "binaural"
    else:
        geometry = "custom"
    center = pos.mean(dim=0)
    spacing = None
    if n_mic >= 2:
        # Force the exact distance computation. The default mode switches to a
        # matmul formulation for more than 25 rows, whose rounding error can
        # leave tiny non-zero self-distances that would survive the `> 0`
        # filter below and be reported as the spacing.
        dists = torch.cdist(pos, pos, compute_mode="donot_use_mm_for_euclid_dist")
        # Drop self-distances and coincident microphones.
        dists = dists[dists > 0]
        if dists.numel() > 0:
            spacing = float(dists.min().item())
    return ArrayAttributes(
        geometry_name=geometry,
        positions=pos,
        orientation=mics.orientation,
        center=center,
        normal=None,
        spacing=spacing,
    )
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def _to_serializable(value: Any) -> Any:
|
|
206
|
+
if value is None:
|
|
207
|
+
return None
|
|
208
|
+
if torch.is_tensor(value):
|
|
209
|
+
return value.detach().cpu().tolist()
|
|
210
|
+
if isinstance(value, Path):
|
|
211
|
+
return str(value)
|
|
212
|
+
if isinstance(value, dict):
|
|
213
|
+
return {k: _to_serializable(v) for k, v in value.items()}
|
|
214
|
+
if isinstance(value, (list, tuple)):
|
|
215
|
+
return [_to_serializable(v) for v in value]
|
|
216
|
+
return value
|
torchrir/plotting.py
CHANGED
|
@@ -20,7 +20,15 @@ def plot_scene_static(
|
|
|
20
20
|
title: Optional[str] = None,
|
|
21
21
|
show: bool = False,
|
|
22
22
|
):
|
|
23
|
-
"""Plot a static room with source and mic positions.
|
|
23
|
+
"""Plot a static room with source and mic positions.
|
|
24
|
+
|
|
25
|
+
Example:
|
|
26
|
+
>>> ax = plot_scene_static(
|
|
27
|
+
... room=[6.0, 4.0, 3.0],
|
|
28
|
+
... sources=[[1.0, 2.0, 1.5]],
|
|
29
|
+
... mics=[[2.0, 2.0, 1.5]],
|
|
30
|
+
... )
|
|
31
|
+
"""
|
|
24
32
|
plt, ax = _setup_axes(ax, room)
|
|
25
33
|
|
|
26
34
|
size = _room_size(room, ax)
|
|
@@ -46,21 +54,35 @@ def plot_scene_dynamic(
|
|
|
46
54
|
src_traj: Tensor | Sequence,
|
|
47
55
|
mic_traj: Tensor | Sequence,
|
|
48
56
|
step: int = 1,
|
|
57
|
+
src_pos: Optional[Tensor | Sequence] = None,
|
|
58
|
+
mic_pos: Optional[Tensor | Sequence] = None,
|
|
49
59
|
ax: Any | None = None,
|
|
50
60
|
title: Optional[str] = None,
|
|
51
61
|
show: bool = False,
|
|
52
62
|
):
|
|
53
|
-
"""Plot source and mic trajectories within a room.
|
|
63
|
+
"""Plot source and mic trajectories within a room.
|
|
64
|
+
|
|
65
|
+
If trajectories are static, only positions are plotted.
|
|
66
|
+
|
|
67
|
+
Example:
|
|
68
|
+
>>> ax = plot_scene_dynamic(
|
|
69
|
+
... room=[6.0, 4.0, 3.0],
|
|
70
|
+
... src_traj=src_traj,
|
|
71
|
+
... mic_traj=mic_traj,
|
|
72
|
+
... )
|
|
73
|
+
"""
|
|
54
74
|
plt, ax = _setup_axes(ax, room)
|
|
55
75
|
|
|
56
76
|
size = _room_size(room, ax)
|
|
57
77
|
_draw_room(ax, size)
|
|
58
78
|
|
|
59
|
-
src_traj = _as_trajectory(src_traj
|
|
60
|
-
mic_traj = _as_trajectory(mic_traj
|
|
79
|
+
src_traj = _as_trajectory(src_traj)
|
|
80
|
+
mic_traj = _as_trajectory(mic_traj)
|
|
81
|
+
src_pos_t = _extract_positions(src_pos, ax) if src_pos is not None else src_traj[0]
|
|
82
|
+
mic_pos_t = _extract_positions(mic_pos, ax) if mic_pos is not None else mic_traj[0]
|
|
61
83
|
|
|
62
|
-
|
|
63
|
-
|
|
84
|
+
_plot_entity(ax, src_traj, src_pos_t, step=step, label="sources", marker="^")
|
|
85
|
+
_plot_entity(ax, mic_traj, mic_pos_t, step=step, label="mics", marker="o")
|
|
64
86
|
|
|
65
87
|
if title:
|
|
66
88
|
ax.set_title(title)
|
|
@@ -70,7 +92,9 @@ def plot_scene_dynamic(
|
|
|
70
92
|
return ax
|
|
71
93
|
|
|
72
94
|
|
|
73
|
-
def _setup_axes(
|
|
95
|
+
def _setup_axes(
|
|
96
|
+
ax: Any | None, room: Room | Sequence[float] | Tensor
|
|
97
|
+
) -> tuple[Any, Any]:
|
|
74
98
|
"""Create 2D/3D axes based on room dimension."""
|
|
75
99
|
import matplotlib.pyplot as plt
|
|
76
100
|
|
|
@@ -109,8 +133,9 @@ def _draw_room_2d(ax: Any, size: Tensor) -> None:
|
|
|
109
133
|
"""Draw a 2D rectangular room."""
|
|
110
134
|
import matplotlib.patches as patches
|
|
111
135
|
|
|
112
|
-
rect = patches.Rectangle(
|
|
113
|
-
|
|
136
|
+
rect = patches.Rectangle(
|
|
137
|
+
(0.0, 0.0), size[0].item(), size[1].item(), fill=False, edgecolor="black"
|
|
138
|
+
)
|
|
114
139
|
ax.add_patch(rect)
|
|
115
140
|
ax.set_xlim(0, size[0].item())
|
|
116
141
|
ax.set_ylim(0, size[1].item())
|
|
@@ -164,7 +189,9 @@ def _draw_room_3d(ax: Any, size: Tensor) -> None:
|
|
|
164
189
|
ax.set_zlabel("z")
|
|
165
190
|
|
|
166
191
|
|
|
167
|
-
def _extract_positions(
|
|
192
|
+
def _extract_positions(
|
|
193
|
+
entity: Source | MicrophoneArray | Tensor | Sequence, ax: Any | None
|
|
194
|
+
) -> Tensor:
|
|
168
195
|
"""Extract positions from Source/MicrophoneArray or raw tensor."""
|
|
169
196
|
if isinstance(entity, (Source, MicrophoneArray)):
|
|
170
197
|
pos = entity.positions
|
|
@@ -176,18 +203,34 @@ def _extract_positions(entity: Source | MicrophoneArray | Tensor | Sequence, ax:
|
|
|
176
203
|
return pos
|
|
177
204
|
|
|
178
205
|
|
|
179
|
-
def _scatter_positions(
|
|
206
|
+
def _scatter_positions(
|
|
207
|
+
ax: Any,
|
|
208
|
+
positions: Tensor,
|
|
209
|
+
*,
|
|
210
|
+
label: str,
|
|
211
|
+
marker: str,
|
|
212
|
+
color: Optional[str] = None,
|
|
213
|
+
) -> None:
|
|
180
214
|
"""Scatter-plot positions in 2D or 3D."""
|
|
181
215
|
if positions.numel() == 0:
|
|
182
216
|
return
|
|
183
217
|
dim = positions.shape[1]
|
|
184
218
|
if dim == 2:
|
|
185
|
-
ax.scatter(
|
|
219
|
+
ax.scatter(
|
|
220
|
+
positions[:, 0], positions[:, 1], label=label, marker=marker, color=color
|
|
221
|
+
)
|
|
186
222
|
else:
|
|
187
|
-
ax.scatter(
|
|
223
|
+
ax.scatter(
|
|
224
|
+
positions[:, 0],
|
|
225
|
+
positions[:, 1],
|
|
226
|
+
positions[:, 2],
|
|
227
|
+
label=label,
|
|
228
|
+
marker=marker,
|
|
229
|
+
color=color,
|
|
230
|
+
)
|
|
188
231
|
|
|
189
232
|
|
|
190
|
-
def _as_trajectory(traj: Tensor | Sequence
|
|
233
|
+
def _as_trajectory(traj: Tensor | Sequence) -> Tensor:
|
|
191
234
|
"""Validate and normalize a trajectory tensor."""
|
|
192
235
|
traj = as_tensor(traj)
|
|
193
236
|
if traj.ndim != 3:
|
|
@@ -195,16 +238,73 @@ def _as_trajectory(traj: Tensor | Sequence, ax: Any | None) -> Tensor:
|
|
|
195
238
|
return traj
|
|
196
239
|
|
|
197
240
|
|
|
198
|
-
def
|
|
199
|
-
|
|
241
|
+
def _plot_entity(
|
|
242
|
+
ax: Any,
|
|
243
|
+
traj: Tensor,
|
|
244
|
+
positions: Tensor,
|
|
245
|
+
*,
|
|
246
|
+
step: int,
|
|
247
|
+
label: str,
|
|
248
|
+
marker: str,
|
|
249
|
+
) -> None:
|
|
250
|
+
"""Plot trajectories and/or static positions with a unified legend entry."""
|
|
200
251
|
if traj.numel() == 0:
|
|
201
252
|
return
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
253
|
+
import matplotlib.pyplot as plt
|
|
254
|
+
|
|
255
|
+
if positions.shape != traj.shape[1:]:
|
|
256
|
+
positions = traj[0]
|
|
257
|
+
moving = _is_moving(traj, positions)
|
|
258
|
+
colors = plt.rcParams.get("axes.prop_cycle", None)
|
|
259
|
+
if colors is not None:
|
|
260
|
+
palette = colors.by_key().get("color", [])
|
|
207
261
|
else:
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
262
|
+
palette = []
|
|
263
|
+
if not palette:
|
|
264
|
+
palette = ["C0", "C1", "C2", "C3", "C4", "C5"]
|
|
265
|
+
|
|
266
|
+
dim = traj.shape[2]
|
|
267
|
+
for idx in range(traj.shape[1]):
|
|
268
|
+
color = palette[idx % len(palette)]
|
|
269
|
+
lbl = label if idx == 0 else "_nolegend_"
|
|
270
|
+
if moving:
|
|
271
|
+
if dim == 2:
|
|
272
|
+
xy = traj[::step, idx]
|
|
273
|
+
ax.plot(
|
|
274
|
+
xy[:, 0],
|
|
275
|
+
xy[:, 1],
|
|
276
|
+
label=lbl,
|
|
277
|
+
color=color,
|
|
278
|
+
marker=marker,
|
|
279
|
+
markevery=[0],
|
|
280
|
+
)
|
|
281
|
+
else:
|
|
282
|
+
xyz = traj[::step, idx]
|
|
283
|
+
ax.plot(
|
|
284
|
+
xyz[:, 0],
|
|
285
|
+
xyz[:, 1],
|
|
286
|
+
xyz[:, 2],
|
|
287
|
+
label=lbl,
|
|
288
|
+
color=color,
|
|
289
|
+
marker=marker,
|
|
290
|
+
markevery=[0],
|
|
291
|
+
)
|
|
292
|
+
pos = positions[idx : idx + 1]
|
|
293
|
+
_scatter_positions(ax, pos, label="_nolegend_", marker=marker, color=color)
|
|
294
|
+
if not moving:
|
|
295
|
+
# ensure legend marker uses the group label
|
|
296
|
+
_scatter_positions(
|
|
297
|
+
ax,
|
|
298
|
+
positions[:1],
|
|
299
|
+
label=label,
|
|
300
|
+
marker=marker,
|
|
301
|
+
color=palette[0],
|
|
302
|
+
)
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
def _is_moving(traj: Tensor, positions: Tensor, *, tol: float = 1e-6) -> bool:
|
|
306
|
+
"""Return True if any trajectory deviates from the provided positions."""
|
|
307
|
+
if traj.numel() == 0:
|
|
308
|
+
return False
|
|
309
|
+
pos0 = positions.unsqueeze(0).expand_as(traj)
|
|
310
|
+
return bool(torch.any(torch.linalg.norm(traj - pos0, dim=-1) > tol).item())
|
torchrir/plotting_utils.py
CHANGED
|
@@ -26,6 +26,8 @@ def plot_scene_and_save(
|
|
|
26
26
|
) -> tuple[list[Path], list[Path]]:
|
|
27
27
|
"""Plot static and dynamic scenes and save images to disk.
|
|
28
28
|
|
|
29
|
+
Dynamic plots show trajectories for moving entities and points for fixed ones.
|
|
30
|
+
|
|
29
31
|
Args:
|
|
30
32
|
out_dir: Output directory for PNGs.
|
|
31
33
|
room: Room size tensor or sequence.
|
|
@@ -41,6 +43,17 @@ def plot_scene_and_save(
|
|
|
41
43
|
|
|
42
44
|
Returns:
|
|
43
45
|
Tuple of (static_paths, dynamic_paths).
|
|
46
|
+
|
|
47
|
+
Example:
|
|
48
|
+
>>> plot_scene_and_save(
|
|
49
|
+
... out_dir=Path("outputs"),
|
|
50
|
+
... room=[6.0, 4.0, 3.0],
|
|
51
|
+
... sources=[[1.0, 2.0, 1.5]],
|
|
52
|
+
... mics=[[2.0, 2.0, 1.5]],
|
|
53
|
+
... src_traj=src_traj,
|
|
54
|
+
... mic_traj=mic_traj,
|
|
55
|
+
... prefix="scene",
|
|
56
|
+
... )
|
|
44
57
|
"""
|
|
45
58
|
out_dir = Path(out_dir)
|
|
46
59
|
out_dir.mkdir(parents=True, exist_ok=True)
|
|
@@ -85,11 +98,12 @@ def plot_scene_and_save(
|
|
|
85
98
|
room=view_room,
|
|
86
99
|
src_traj=view_src_traj,
|
|
87
100
|
mic_traj=view_mic_traj,
|
|
101
|
+
src_pos=view_src,
|
|
102
|
+
mic_pos=view_mic,
|
|
88
103
|
step=step,
|
|
89
104
|
title=f"Room scene ({view_dim}D trajectories)",
|
|
90
105
|
show=False,
|
|
91
106
|
)
|
|
92
|
-
_overlay_positions(ax, view_src, view_mic)
|
|
93
107
|
dynamic_path = out_dir / f"{prefix}_dynamic_{view_dim}d.png"
|
|
94
108
|
_save_axes(ax, dynamic_path, show=show)
|
|
95
109
|
dynamic_paths.append(dynamic_path)
|
|
@@ -113,7 +127,10 @@ def _positions_to_cpu(entity: torch.Tensor | object) -> torch.Tensor:
|
|
|
113
127
|
return pos
|
|
114
128
|
|
|
115
129
|
|
|
116
|
-
def _traj_steps(
|
|
130
|
+
def _traj_steps(
|
|
131
|
+
src_traj: Optional[torch.Tensor | Sequence],
|
|
132
|
+
mic_traj: Optional[torch.Tensor | Sequence],
|
|
133
|
+
) -> int:
|
|
117
134
|
"""Infer the number of trajectory steps."""
|
|
118
135
|
if src_traj is not None:
|
|
119
136
|
return int(_to_cpu(src_traj).shape[0])
|
|
@@ -142,32 +159,3 @@ def _save_axes(ax: Any, path: Path, *, show: bool) -> None:
|
|
|
142
159
|
if show:
|
|
143
160
|
plt.show()
|
|
144
161
|
plt.close(fig)
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
def _overlay_positions(ax: Any, sources: torch.Tensor, mics: torch.Tensor) -> None:
|
|
148
|
-
"""Overlay static source and mic positions on an axis."""
|
|
149
|
-
if sources.numel() > 0:
|
|
150
|
-
if sources.shape[1] == 2:
|
|
151
|
-
ax.scatter(sources[:, 0], sources[:, 1], marker="^", label="sources", color="tab:green")
|
|
152
|
-
else:
|
|
153
|
-
ax.scatter(
|
|
154
|
-
sources[:, 0],
|
|
155
|
-
sources[:, 1],
|
|
156
|
-
sources[:, 2],
|
|
157
|
-
marker="^",
|
|
158
|
-
label="sources",
|
|
159
|
-
color="tab:green",
|
|
160
|
-
)
|
|
161
|
-
if mics.numel() > 0:
|
|
162
|
-
if mics.shape[1] == 2:
|
|
163
|
-
ax.scatter(mics[:, 0], mics[:, 1], marker="o", label="mics", color="tab:orange")
|
|
164
|
-
else:
|
|
165
|
-
ax.scatter(
|
|
166
|
-
mics[:, 0],
|
|
167
|
-
mics[:, 1],
|
|
168
|
-
mics[:, 2],
|
|
169
|
-
marker="o",
|
|
170
|
-
label="mics",
|
|
171
|
-
color="tab:orange",
|
|
172
|
-
)
|
|
173
|
-
ax.legend(loc="best")
|
torchrir/results.py
CHANGED
|
@@ -13,7 +13,13 @@ from .scene import Scene
|
|
|
13
13
|
|
|
14
14
|
@dataclass(frozen=True)
|
|
15
15
|
class RIRResult:
|
|
16
|
-
"""Container for RIRs with metadata.
|
|
16
|
+
"""Container for RIRs with metadata.
|
|
17
|
+
|
|
18
|
+
Example:
|
|
19
|
+
>>> from torchrir import ISMSimulator
|
|
20
|
+
>>> result = ISMSimulator().simulate(scene, config)
|
|
21
|
+
>>> rirs = result.rirs
|
|
22
|
+
"""
|
|
17
23
|
|
|
18
24
|
rirs: Tensor
|
|
19
25
|
scene: Scene
|
torchrir/room.py
CHANGED
|
@@ -13,7 +13,11 @@ from .utils import as_tensor, ensure_dim
|
|
|
13
13
|
|
|
14
14
|
@dataclass(frozen=True)
|
|
15
15
|
class Room:
|
|
16
|
-
"""Room geometry and acoustic parameters.
|
|
16
|
+
"""Room geometry and acoustic parameters.
|
|
17
|
+
|
|
18
|
+
Example:
|
|
19
|
+
>>> room = Room.shoebox(size=[6.0, 4.0, 3.0], fs=16000, beta=[0.9] * 6)
|
|
20
|
+
"""
|
|
17
21
|
|
|
18
22
|
size: Tensor
|
|
19
23
|
fs: float
|
|
@@ -43,7 +47,11 @@ class Room:
|
|
|
43
47
|
device: Optional[torch.device | str] = None,
|
|
44
48
|
dtype: Optional[torch.dtype] = None,
|
|
45
49
|
) -> "Room":
|
|
46
|
-
"""Create a rectangular (shoebox) room.
|
|
50
|
+
"""Create a rectangular (shoebox) room.
|
|
51
|
+
|
|
52
|
+
Example:
|
|
53
|
+
>>> room = Room.shoebox(size=[6.0, 4.0, 3.0], fs=16000, beta=[0.9] * 6)
|
|
54
|
+
"""
|
|
47
55
|
size_t = as_tensor(size, device=device, dtype=dtype)
|
|
48
56
|
size_t = ensure_dim(size_t)
|
|
49
57
|
beta_t = None
|
|
@@ -54,7 +62,11 @@ class Room:
|
|
|
54
62
|
|
|
55
63
|
@dataclass(frozen=True)
|
|
56
64
|
class Source:
|
|
57
|
-
"""Source container with positions and optional orientation.
|
|
65
|
+
"""Source container with positions and optional orientation.
|
|
66
|
+
|
|
67
|
+
Example:
|
|
68
|
+
>>> sources = Source.from_positions([[1.0, 2.0, 1.5]])
|
|
69
|
+
"""
|
|
58
70
|
|
|
59
71
|
positions: Tensor
|
|
60
72
|
orientation: Optional[Tensor] = None
|
|
@@ -70,20 +82,6 @@ class Source:
|
|
|
70
82
|
"""Return a new Source with updated fields."""
|
|
71
83
|
return replace(self, **kwargs)
|
|
72
84
|
|
|
73
|
-
@classmethod
|
|
74
|
-
def positions(
|
|
75
|
-
cls,
|
|
76
|
-
positions: Sequence[Sequence[float]] | Tensor,
|
|
77
|
-
*,
|
|
78
|
-
orientation: Optional[Sequence[float] | Tensor] = None,
|
|
79
|
-
device: Optional[torch.device | str] = None,
|
|
80
|
-
dtype: Optional[torch.dtype] = None,
|
|
81
|
-
) -> "Source":
|
|
82
|
-
"""Construct a Source from positions."""
|
|
83
|
-
return cls.from_positions(
|
|
84
|
-
positions, orientation=orientation, device=device, dtype=dtype
|
|
85
|
-
)
|
|
86
|
-
|
|
87
85
|
@classmethod
|
|
88
86
|
def from_positions(
|
|
89
87
|
cls,
|
|
@@ -103,7 +101,11 @@ class Source:
|
|
|
103
101
|
|
|
104
102
|
@dataclass(frozen=True)
|
|
105
103
|
class MicrophoneArray:
|
|
106
|
-
"""Microphone array container.
|
|
104
|
+
"""Microphone array container.
|
|
105
|
+
|
|
106
|
+
Example:
|
|
107
|
+
>>> mics = MicrophoneArray.from_positions([[2.0, 2.0, 1.5]])
|
|
108
|
+
"""
|
|
107
109
|
|
|
108
110
|
positions: Tensor
|
|
109
111
|
orientation: Optional[Tensor] = None
|
|
@@ -119,20 +121,6 @@ class MicrophoneArray:
|
|
|
119
121
|
"""Return a new MicrophoneArray with updated fields."""
|
|
120
122
|
return replace(self, **kwargs)
|
|
121
123
|
|
|
122
|
-
@classmethod
|
|
123
|
-
def positions(
|
|
124
|
-
cls,
|
|
125
|
-
positions: Sequence[Sequence[float]] | Tensor,
|
|
126
|
-
*,
|
|
127
|
-
orientation: Optional[Sequence[float] | Tensor] = None,
|
|
128
|
-
device: Optional[torch.device | str] = None,
|
|
129
|
-
dtype: Optional[torch.dtype] = None,
|
|
130
|
-
) -> "MicrophoneArray":
|
|
131
|
-
"""Construct a MicrophoneArray from positions."""
|
|
132
|
-
return cls.from_positions(
|
|
133
|
-
positions, orientation=orientation, device=device, dtype=dtype
|
|
134
|
-
)
|
|
135
|
-
|
|
136
124
|
@classmethod
|
|
137
125
|
def from_positions(
|
|
138
126
|
cls,
|
torchrir/scene.py
CHANGED
|
@@ -13,7 +13,12 @@ from .room import MicrophoneArray, Room, Source
|
|
|
13
13
|
|
|
14
14
|
@dataclass(frozen=True)
|
|
15
15
|
class Scene:
|
|
16
|
-
"""Container for room, sources, microphones, and optional trajectories.
|
|
16
|
+
"""Container for room, sources, microphones, and optional trajectories.
|
|
17
|
+
|
|
18
|
+
Example:
|
|
19
|
+
>>> scene = Scene(room=room, sources=sources, mics=mics, src_traj=src_traj, mic_traj=mic_traj)
|
|
20
|
+
>>> scene.validate()
|
|
21
|
+
"""
|
|
17
22
|
|
|
18
23
|
room: Room
|
|
19
24
|
sources: Source
|