torchrir-0.1.2-py3-none-any.whl → torchrir-0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchrir/__init__.py +16 -1
- torchrir/animation.py +17 -14
- torchrir/core.py +176 -35
- torchrir/datasets/__init__.py +9 -3
- torchrir/datasets/base.py +43 -3
- torchrir/datasets/cmu_arctic.py +9 -20
- torchrir/datasets/collate.py +90 -0
- torchrir/datasets/librispeech.py +175 -0
- torchrir/datasets/template.py +3 -1
- torchrir/datasets/utils.py +23 -1
- torchrir/dynamic.py +3 -1
- torchrir/plotting.py +13 -6
- torchrir/plotting_utils.py +4 -1
- torchrir/room.py +2 -38
- torchrir/scene_utils.py +6 -2
- torchrir/signal.py +24 -10
- torchrir/simulators.py +12 -4
- torchrir/utils.py +1 -1
- torchrir-0.2.0.dist-info/METADATA +70 -0
- torchrir-0.2.0.dist-info/RECORD +30 -0
- torchrir-0.1.2.dist-info/METADATA +0 -271
- torchrir-0.1.2.dist-info/RECORD +0 -28
- {torchrir-0.1.2.dist-info → torchrir-0.2.0.dist-info}/WHEEL +0 -0
- {torchrir-0.1.2.dist-info → torchrir-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {torchrir-0.1.2.dist-info → torchrir-0.2.0.dist-info}/licenses/NOTICE +0 -0
- {torchrir-0.1.2.dist-info → torchrir-0.2.0.dist-info}/top_level.txt +0 -0
torchrir/datasets/librispeech.py ADDED

```diff
@@ -0,0 +1,175 @@
+from __future__ import annotations
+
+"""LibriSpeech dataset helpers."""
+
+import logging
+import tarfile
+import urllib.request
+from dataclasses import dataclass
+from pathlib import Path
+from typing import List, Tuple
+
+import torch
+
+from .base import BaseDataset
+from .utils import load_wav_mono
+
+BASE_URL = "https://www.openslr.org/resources/12"
+VALID_SUBSETS = {
+    "dev-clean",
+    "dev-other",
+    "test-clean",
+    "test-other",
+    "train-clean-100",
+    "train-clean-360",
+    "train-other-500",
+}
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class LibriSpeechSentence:
+    """Sentence metadata from LibriSpeech."""
+
+    utterance_id: str
+    text: str
+    speaker_id: str
+    chapter_id: str
+
+
+class LibriSpeechDataset(BaseDataset):
+    """LibriSpeech dataset loader.
+
+    Example:
+        >>> dataset = LibriSpeechDataset(Path("datasets/librispeech"), subset="train-clean-100", download=True)
+        >>> audio, fs = dataset.load_wav("103-1240-0000")
+    """
+
+    def __init__(
+        self, root: Path, subset: str = "train-clean-100", download: bool = False
+    ) -> None:
+        """Initialize a LibriSpeech dataset handle.
+
+        Args:
+            root: Root directory where the dataset is stored.
+            subset: LibriSpeech subset name (e.g., "train-clean-100").
+            download: Download and extract if missing.
+        """
+        if subset not in VALID_SUBSETS:
+            raise ValueError(f"unsupported subset: {subset}")
+        self.root = Path(root)
+        self.subset = subset
+        self._archive_name = f"{subset}.tar.gz"
+        self._base_dir = self.root / "LibriSpeech"
+        self._subset_dir = self._base_dir / subset
+
+        if download:
+            self._download_and_extract()
+
+        if not self._subset_dir.exists():
+            raise FileNotFoundError(
+                "dataset not found; run with download=True or place the archive under "
+                f"{self.root}"
+            )
+
+    def list_speakers(self) -> List[str]:
+        """Return available speaker IDs."""
+        if not self._subset_dir.exists():
+            return []
+        return sorted([p.name for p in self._subset_dir.iterdir() if p.is_dir()])
+
+    def available_sentences(self) -> List[LibriSpeechSentence]:
+        """Return sentences that have a corresponding audio file."""
+        sentences: List[LibriSpeechSentence] = []
+        for trans_path in self._subset_dir.rglob("*.trans.txt"):
+            chapter_dir = trans_path.parent
+            speaker_id = chapter_dir.parent.name
+            chapter_id = chapter_dir.name
+            with trans_path.open("r", encoding="utf-8") as f:
+                for line in f:
+                    line = line.strip()
+                    if not line:
+                        continue
+                    utt_id, text = _parse_text_line(line)
+                    wav_path = chapter_dir / f"{utt_id}.flac"
+                    if wav_path.exists():
+                        sentences.append(
+                            LibriSpeechSentence(
+                                utterance_id=utt_id,
+                                text=text,
+                                speaker_id=speaker_id,
+                                chapter_id=chapter_id,
+                            )
+                        )
+        return sentences
+
+    def load_wav(self, utterance_id: str) -> Tuple[torch.Tensor, int]:
+        """Load a mono wav for the given utterance ID."""
+        speaker_id, chapter_id, _ = utterance_id.split("-", 2)
+        path = self._subset_dir / speaker_id / chapter_id / f"{utterance_id}.flac"
+        return load_wav_mono(path)
+
+    def _download_and_extract(self) -> None:
+        """Download and extract the subset archive if needed."""
+        self.root.mkdir(parents=True, exist_ok=True)
+        archive_path = self.root / self._archive_name
+        url = f"{BASE_URL}/{self._archive_name}"
+
+        if not archive_path.exists():
+            logger.info("Downloading %s", url)
+            _download(url, archive_path)
+        if not self._subset_dir.exists():
+            logger.info("Extracting %s", archive_path)
+            try:
+                with tarfile.open(archive_path, "r:gz") as tar:
+                    tar.extractall(self.root)
+            except (tarfile.ReadError, EOFError, OSError) as exc:
+                logger.warning("Extraction failed (%s); re-downloading.", exc)
+                if archive_path.exists():
+                    archive_path.unlink()
+                _download(url, archive_path)
+                with tarfile.open(archive_path, "r:gz") as tar:
+                    tar.extractall(self.root)
+
+
+def _download(url: str, dest: Path, retries: int = 1) -> None:
+    """Download a file with retry and resume-safe temp file."""
+    for attempt in range(retries + 1):
+        try:
+            _stream_download(url, dest)
+            return
+        except Exception as exc:
+            if dest.exists():
+                dest.unlink()
+            if attempt >= retries:
+                raise
+            logger.warning("Download failed (%s); retrying...", exc)
+
+
+def _stream_download(url: str, dest: Path) -> None:
+    """Stream a URL to disk with a progress indicator."""
+    tmp_path = dest.with_suffix(dest.suffix + ".part")
+    if tmp_path.exists():
+        tmp_path.unlink()
+
+    with urllib.request.urlopen(url) as response:
+        total = response.length or 0
+        downloaded = 0
+        chunk_size = 1024 * 1024
+        with tmp_path.open("wb") as f:
+            while True:
+                chunk = response.read(chunk_size)
+                if not chunk:
+                    break
+                f.write(chunk)
+                downloaded += len(chunk)
+    if total > 0 and downloaded != total:
+        raise IOError(f"incomplete download: {downloaded} of {total} bytes")
+    tmp_path.replace(dest)
+
+
+def _parse_text_line(line: str) -> Tuple[str, str]:
+    """Parse a LibriSpeech transcript line into (utterance_id, text)."""
+    left, _, right = line.partition(" ")
+    return left, right.strip()
```
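The docstrings above already sketch the intended flow; put together, a minimal usage sketch (the subset, paths, and printed values are illustrative, and `LibriSpeechDataset` is imported from its module rather than assuming a top-level re-export):

```python
from pathlib import Path

from torchrir.datasets.librispeech import LibriSpeechDataset

# Assumes dev-clean is already under datasets/librispeech/LibriSpeech/;
# pass download=True to fetch and extract it instead.
dataset = LibriSpeechDataset(Path("datasets/librispeech"), subset="dev-clean")

print(dataset.list_speakers()[:3])         # sorted speaker IDs
sentences = dataset.available_sentences()  # only entries whose .flac exists on disk

# Utterance IDs are "speaker-chapter-index"; load_wav splits on the first two "-".
audio, fs = dataset.load_wav(sentences[0].utterance_id)
print(audio.shape, fs)                     # 1-D mono tensor, 16000 for LibriSpeech
```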
torchrir/datasets/template.py CHANGED

```diff
@@ -34,7 +34,9 @@ class TemplateDataset(BaseDataset):
     protocol intact.
     """
 
-    def __init__(self, root: Path, speaker: str = "default", download: bool = False) -> None:
+    def __init__(
+        self, root: Path, speaker: str = "default", download: bool = False
+    ) -> None:
         self.root = Path(root)
         self.speaker = speaker
         if download:
```
torchrir/datasets/utils.py CHANGED

```diff
@@ -3,6 +3,7 @@ from __future__ import annotations
 """Dataset-agnostic utilities."""
 
 import random
+from pathlib import Path
 from typing import Callable, List, Optional, Sequence, Tuple
 
 import torch
@@ -10,7 +11,9 @@ import torch
 from .base import BaseDataset, SentenceLike
 
 
-def choose_speakers(dataset: BaseDataset, num_sources: int, rng: random.Random) -> List[str]:
+def choose_speakers(
+    dataset: BaseDataset, num_sources: int, rng: random.Random
+) -> List[str]:
     """Select unique speakers for the requested number of sources.
 
     Example:
@@ -89,4 +92,23 @@ def load_dataset_sources(
         info.append((speaker, utterance_ids))
 
     stacked = torch.stack(signals, dim=0)
+    if fs is None:
+        raise RuntimeError("no audio loaded from dataset sources")
     return stacked, int(fs), info
+
+
+def load_wav_mono(path: Path) -> Tuple[torch.Tensor, int]:
+    """Load a wav/flac file and return mono audio and sample rate.
+
+    Example:
+        >>> audio, fs = load_wav_mono(Path("datasets/cmu_arctic/ARCTIC/.../wav/arctic_a0001.wav"))
+    """
+    import soundfile as sf
+
+    audio, sample_rate = sf.read(str(path), dtype="float32", always_2d=True)
+    audio_t = torch.from_numpy(audio)
+    if audio_t.shape[1] > 1:
+        audio_t = audio_t.mean(dim=1)
+    else:
+        audio_t = audio_t.squeeze(1)
+    return audio_t, sample_rate
```
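Since `load_wav_mono` downmixes by averaging channels, stereo content collapses to its per-sample mean. A self-contained sketch (it writes a throwaway stereo file first; `soundfile` is the same dependency the helper itself imports):

```python
from pathlib import Path

import soundfile as sf
import torch

from torchrir.datasets.utils import load_wav_mono

# Two opposite-polarity channels at ±0.5: their mean is exactly zero.
stereo = torch.stack([torch.full((16000,), 0.5), torch.full((16000,), -0.5)], dim=1)
sf.write("stereo_demo.wav", stereo.numpy(), 16000)

audio, fs = load_wav_mono(Path("stereo_demo.wav"))
print(audio.shape, fs)    # torch.Size([16000]) 16000
print(audio.abs().max())  # tensor(0.): channels cancel under the mean
```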
torchrir/dynamic.py CHANGED

```diff
@@ -44,7 +44,9 @@ class DynamicConvolver:
             if self.hop is None:
                 raise ValueError("hop must be provided for hop mode")
             return _convolve_dynamic_hop(signal, rirs, self.hop)
-        return _convolve_dynamic_trajectory(signal, rirs, timestamps=self.timestamps, fs=self.fs)
+        return _convolve_dynamic_trajectory(
+            signal, rirs, timestamps=self.timestamps, fs=self.fs
+        )
 
 
 def _convolve_dynamic_hop(signal: Tensor, rirs: Tensor, hop: int) -> Tensor:
```
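For context, the two dispatch branches above correspond to the convolver's two modes. A sketch of both calls, assuming the constructor keywords mirror the attributes used here (`mode`, `hop`, `timestamps`, `fs`); the trajectory call matches the README example later in this diff:

```python
import torch

from torchrir import DynamicConvolver

signal = torch.randn(1, 16000)     # (n_src, n_samples)
rirs = torch.randn(10, 1, 2, 512)  # (t_steps, n_src, n_mic, rir_len)

# Trajectory mode: RIR windows are spread over the signal (timestamps/fs optional).
y = DynamicConvolver(mode="trajectory").convolve(signal, rirs)

# Hop mode: a new RIR every `hop` samples; omitting hop raises the ValueError above.
y = DynamicConvolver(mode="hop", hop=1600).convolve(signal, rirs)
```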
torchrir/plotting.py CHANGED

```diff
@@ -92,7 +92,9 @@ def plot_scene_dynamic(
     return ax
 
 
-def _setup_axes(ax: Any | None, room: Room | Sequence[float] | Tensor) -> tuple[Any, Any]:
+def _setup_axes(
+    ax: Any | None, room: Room | Sequence[float] | Tensor
+) -> tuple[Any, Any]:
     """Create 2D/3D axes based on room dimension."""
     import matplotlib.pyplot as plt
 
@@ -131,8 +133,9 @@ def _draw_room_2d(ax: Any, size: Tensor) -> None:
     """Draw a 2D rectangular room."""
     import matplotlib.patches as patches
 
-    rect = patches.Rectangle(
-        (0.0, 0.0), size[0].item(), size[1].item(), fill=False, edgecolor="black")
+    rect = patches.Rectangle(
+        (0.0, 0.0), size[0].item(), size[1].item(), fill=False, edgecolor="black"
+    )
     ax.add_patch(rect)
     ax.set_xlim(0, size[0].item())
     ax.set_ylim(0, size[1].item())
@@ -186,7 +189,9 @@ def _draw_room_3d(ax: Any, size: Tensor) -> None:
     ax.set_zlabel("z")
 
 
-def _extract_positions(entity: Source | MicrophoneArray | Tensor | Sequence, ax: Any | None) -> Tensor:
+def _extract_positions(
+    entity: Source | MicrophoneArray | Tensor | Sequence, ax: Any | None
+) -> Tensor:
     """Extract positions from Source/MicrophoneArray or raw tensor."""
     if isinstance(entity, (Source, MicrophoneArray)):
         pos = entity.positions
@@ -211,7 +216,9 @@ def _scatter_positions(
         return
     dim = positions.shape[1]
     if dim == 2:
-        ax.scatter(positions[:, 0], positions[:, 1], label=label, marker=marker, color=color)
+        ax.scatter(
+            positions[:, 0], positions[:, 1], label=label, marker=marker, color=color
+        )
     else:
         ax.scatter(
             positions[:, 0],
@@ -300,4 +307,4 @@ def _is_moving(traj: Tensor, positions: Tensor, *, tol: float = 1e-6) -> bool:
     if traj.numel() == 0:
         return False
     pos0 = positions.unsqueeze(0).expand_as(traj)
-    return torch.any(torch.linalg.norm(traj - pos0, dim=-1) > tol).item()
+    return bool(torch.any(torch.linalg.norm(traj - pos0, dim=-1) > tol).item())
```
torchrir/plotting_utils.py CHANGED

```diff
@@ -127,7 +127,10 @@ def _positions_to_cpu(entity: torch.Tensor | object) -> torch.Tensor:
     return pos
 
 
-def _traj_steps(src_traj: Optional[torch.Tensor | Sequence], mic_traj: Optional[torch.Tensor | Sequence]) -> int:
+def _traj_steps(
+    src_traj: Optional[torch.Tensor | Sequence],
+    mic_traj: Optional[torch.Tensor | Sequence],
+) -> int:
     """Infer the number of trajectory steps."""
     if src_traj is not None:
         return int(_to_cpu(src_traj).shape[0])
```
torchrir/room.py CHANGED

```diff
@@ -65,7 +65,7 @@ class Source:
     """Source container with positions and optional orientation.
 
     Example:
-        >>> sources = Source.positions([[1.0, 2.0, 1.5]])
+        >>> sources = Source.from_positions([[1.0, 2.0, 1.5]])
     """
 
     positions: Tensor
@@ -82,24 +82,6 @@ class Source:
         """Return a new Source with updated fields."""
         return replace(self, **kwargs)
 
-    @classmethod
-    def positions(
-        cls,
-        positions: Sequence[Sequence[float]] | Tensor,
-        *,
-        orientation: Optional[Sequence[float] | Tensor] = None,
-        device: Optional[torch.device | str] = None,
-        dtype: Optional[torch.dtype] = None,
-    ) -> "Source":
-        """Construct a Source from positions.
-
-        Example:
-            >>> sources = Source.positions([[1.0, 2.0, 1.5]])
-        """
-        return cls.from_positions(
-            positions, orientation=orientation, device=device, dtype=dtype
-        )
-
     @classmethod
     def from_positions(
         cls,
@@ -122,7 +104,7 @@ class MicrophoneArray:
     """Microphone array container.
 
    Example:
-        >>> mics = MicrophoneArray.positions([[2.0, 2.0, 1.5]])
+        >>> mics = MicrophoneArray.from_positions([[2.0, 2.0, 1.5]])
     """
 
     positions: Tensor
@@ -139,24 +121,6 @@ class MicrophoneArray:
         """Return a new MicrophoneArray with updated fields."""
         return replace(self, **kwargs)
 
-    @classmethod
-    def positions(
-        cls,
-        positions: Sequence[Sequence[float]] | Tensor,
-        *,
-        orientation: Optional[Sequence[float] | Tensor] = None,
-        device: Optional[torch.device | str] = None,
-        dtype: Optional[torch.dtype] = None,
-    ) -> "MicrophoneArray":
-        """Construct a MicrophoneArray from positions.
-
-        Example:
-            >>> mics = MicrophoneArray.positions([[2.0, 2.0, 1.5]])
-        """
-        return cls.from_positions(
-            positions, orientation=orientation, device=device, dtype=dtype
-        )
-
     @classmethod
     def from_positions(
         cls,
```
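The removed `positions` classmethods collided with the `positions` dataclass field of the same name, which is presumably why 0.2.0 keeps only `from_positions`; the docstring examples are fixed to match. Migration is mechanical (keyword names taken from the removed signatures):

```python
import torch

from torchrir import MicrophoneArray, Source

# 0.1.2: Source.positions([[1.0, 2.0, 1.5]]), removed in 0.2.0
sources = Source.from_positions([[1.0, 2.0, 1.5]])
mics = MicrophoneArray.from_positions(
    [[2.0, 2.0, 1.5]], device="cpu", dtype=torch.float32
)

print(sources.positions)  # plain Tensor field, no longer shadowed by a classmethod
```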
torchrir/scene_utils.py CHANGED

```diff
@@ -32,7 +32,9 @@ def sample_positions(
     return torch.tensor(coords, dtype=torch.float32)
 
 
-def linear_trajectory(start: torch.Tensor, end: torch.Tensor, steps: int) -> torch.Tensor:
+def linear_trajectory(
+    start: torch.Tensor, end: torch.Tensor, steps: int
+) -> torch.Tensor:
     """Create a linear trajectory between start and end.
 
     Example:
@@ -58,7 +60,9 @@ def binaural_mic_positions(center: torch.Tensor, offset: float = 0.08) -> torch.Tensor:
     return torch.stack([left, right], dim=0)
 
 
-def clamp_positions(positions: torch.Tensor, room_size: torch.Tensor, margin: float = 0.1) -> torch.Tensor:
+def clamp_positions(
+    positions: torch.Tensor, room_size: torch.Tensor, margin: float = 0.1
+) -> torch.Tensor:
     """Clamp positions to remain inside the room with a margin.
 
     Example:
```
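Pure formatting here, but the pair is worth seeing together: a trajectory that drifts outside the room is pulled back by `clamp_positions`. A sketch assuming the obvious semantics of the docstrings (endpoints included in the trajectory; clamping to `[margin, size - margin]` per axis):

```python
import torch

from torchrir.scene_utils import clamp_positions, linear_trajectory

room_size = torch.tensor([6.0, 4.0, 3.0])

# 20 steps from one corner toward a point past the far wall.
traj = linear_trajectory(
    torch.tensor([0.5, 0.5, 1.5]), torch.tensor([7.0, 4.0, 1.5]), steps=20
)

traj = clamp_positions(traj, room_size, margin=0.1)
print(traj[-1])  # tensor([5.9000, 3.9000, 1.5000]) under the assumed clamp rule
```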
torchrir/signal.py CHANGED

```diff
@@ -117,9 +117,9 @@ def _convolve_dynamic_rir_trajectory(
     else:
         step_fs = n_samples / t_steps
         ts_dtype = torch.float32 if signal.device.type == "mps" else torch.float64
-        w_ini = (
-            torch.arange(t_steps, device=signal.device, dtype=ts_dtype) * step_fs
-        )
+        w_ini = (
+            torch.arange(t_steps, device=signal.device, dtype=ts_dtype) * step_fs
+        ).to(torch.long)
 
     w_ini = torch.cat(
         [w_ini, torch.tensor([n_samples], device=signal.device, dtype=torch.long)]
@@ -132,14 +132,18 @@ def _convolve_dynamic_rir_trajectory(
     )
 
     max_len = int(w_len.max().item())
-    segments = torch.zeros((t_steps, n_src, max_len), dtype=signal.dtype, device=signal.device)
+    segments = torch.zeros(
+        (t_steps, n_src, max_len), dtype=signal.dtype, device=signal.device
+    )
     for t in range(t_steps):
         start = int(w_ini[t].item())
         end = int(w_ini[t + 1].item())
         if end > start:
             segments[t, :, : end - start] = signal[:, start:end]
 
-    out = torch.zeros((n_mic, n_samples + rir_len - 1), dtype=signal.dtype, device=signal.device)
+    out = torch.zeros(
+        (n_mic, n_samples + rir_len - 1), dtype=signal.dtype, device=signal.device
+    )
 
     for t in range(t_steps):
         seg_len = int(w_len[t].item())
@@ -166,7 +170,9 @@ def _convolve_dynamic_rir_trajectory_batched(
     """GPU-friendly batched trajectory convolution using FFT."""
     n_samples = signal.shape[1]
     t_steps, n_src, n_mic, rir_len = rirs.shape
-    out = torch.zeros((n_mic, n_samples + rir_len - 1), dtype=signal.dtype, device=signal.device)
+    out = torch.zeros(
+        (n_mic, n_samples + rir_len - 1), dtype=signal.dtype, device=signal.device
+    )
 
     for t0 in range(0, t_steps, chunk_size):
         t1 = min(t0 + chunk_size, t_steps)
@@ -174,7 +180,9 @@ def _convolve_dynamic_rir_trajectory_batched(
         max_len = int(lengths.max().item())
         if max_len == 0:
             continue
-        segments = torch.zeros((t1 - t0, n_src, max_len), dtype=signal.dtype, device=signal.device)
+        segments = torch.zeros(
+            (t1 - t0, n_src, max_len), dtype=signal.dtype, device=signal.device
+        )
         for idx, t in enumerate(range(t0, t1)):
             start = int(w_ini[t].item())
             end = int(w_ini[t + 1].item())
@@ -190,7 +198,9 @@ def _convolve_dynamic_rir_trajectory_batched(
             dtype=signal.dtype,
             device=signal.device,
         )
-        conv = torch.fft.irfft(seg_f[:, :, None, :] * rir_f, n=fft_len, dim=-1, out=conv_out)
+        conv = torch.fft.irfft(
+            seg_f[:, :, None, :] * rir_f, n=fft_len, dim=-1, out=conv_out
+        )
         conv = conv[..., :conv_len]
         conv_sum = conv.sum(dim=1)
 
@@ -199,7 +209,9 @@ def _convolve_dynamic_rir_trajectory_batched(
             if seg_len == 0:
                 continue
             start = int(w_ini[t].item())
-            out[:, start : start + seg_len + rir_len - 1] += conv_sum[idx, :, : seg_len + rir_len - 1]
+            out[:, start : start + seg_len + rir_len - 1] += conv_sum[
+                idx, :, : seg_len + rir_len - 1
+            ]
 
     return out.squeeze(0) if n_mic == 1 else out
 
@@ -221,7 +233,9 @@ def _ensure_static_rirs(rirs: Tensor) -> Tensor:
         return rirs.view(1, rirs.shape[0], rirs.shape[1])
     if rirs.ndim == 3:
         return rirs
-    raise ValueError("rirs must have shape (rir_len,), (n_mic, rir_len), or (n_src, n_mic, rir_len)")
+    raise ValueError(
+        "rirs must have shape (rir_len,), (n_mic, rir_len), or (n_src, n_mic, rir_len)"
+    )
 
 
 def _ensure_dynamic_rirs(rirs: Tensor, signal: Tensor) -> Tensor:
```
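The `w_ini` change converts the window starts to integer sample indices right away instead of carrying floats forward. In isolation the boundary arithmetic looks like this (a standalone sketch of the same expressions, not a call into the library):

```python
import torch

n_samples, t_steps = 16000, 7
step_fs = n_samples / t_steps  # about 2285.71 samples per RIR step

# Evenly spaced float starts, truncated to long, with n_samples as the final fence post.
w_ini = (torch.arange(t_steps, dtype=torch.float64) * step_fs).to(torch.long)
w_ini = torch.cat([w_ini, torch.tensor([n_samples], dtype=torch.long)])
w_len = w_ini[1:] - w_ini[:-1]

print(w_ini.tolist())  # [0, 2285, 4571, 6857, 9142, 11428, 13714, 16000]
print(w_len.tolist())  # [2285, 2286, 2286, 2285, 2286, 2286, 2286] -> sums to 16000
```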
torchrir/simulators.py CHANGED

```diff
@@ -18,7 +18,9 @@ from .scene import Scene
 class RIRSimulator(Protocol):
     """Strategy interface for RIR simulation backends."""
 
-    def simulate(self, scene: Scene, config: SimulationConfig | None = None) -> RIRResult:
+    def simulate(
+        self, scene: Scene, config: SimulationConfig | None = None
+    ) -> RIRResult:
         """Run a simulation and return the result."""
 
 
@@ -30,7 +32,9 @@ class ISMSimulator:
     >>> result = ISMSimulator().simulate(scene, config)
     """
 
-    def simulate(self, scene: Scene, config: SimulationConfig | None = None) -> RIRResult:
+    def simulate(
+        self, scene: Scene, config: SimulationConfig | None = None
+    ) -> RIRResult:
         scene.validate()
         cfg = config or default_config()
         if scene.is_dynamic():
@@ -71,7 +75,9 @@ class RayTracingSimulator:
     reuse Scene/SimulationConfig for inputs and keep output shape parity.
     """
 
-    def simulate(self, scene: Scene, config: SimulationConfig | None = None) -> RIRResult:
+    def simulate(
+        self, scene: Scene, config: SimulationConfig | None = None
+    ) -> RIRResult:
         raise NotImplementedError("RayTracingSimulator is not implemented yet")
 
 
@@ -86,5 +92,7 @@ class FDTDSimulator:
     RIRResult with the same metadata contract as ISM.
     """
 
-    def simulate(self, scene: Scene, config: SimulationConfig | None = None) -> RIRResult:
+    def simulate(
+        self, scene: Scene, config: SimulationConfig | None = None
+    ) -> RIRResult:
         raise NotImplementedError("FDTDSimulator is not implemented yet")
```
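All four `simulate` signatures are now wrapped identically, which makes the strategy pattern easy to read off: any backend satisfying `RIRSimulator` is interchangeable at the call site. A sketch of protocol-typed dispatch (scene/config construction elided; the `ISMSimulator` call mirrors its docstring example):

```python
from torchrir.simulators import ISMSimulator, RayTracingSimulator, RIRSimulator

def run(simulator: RIRSimulator, scene, config=None):
    # Structural typing: anything with the simulate() signature above qualifies.
    return simulator.simulate(scene, config)

# result = run(ISMSimulator(), scene, config)  # implemented backend
# run(RayTracingSimulator(), scene)            # raises NotImplementedError for now
```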
torchrir/utils.py CHANGED

```diff
@@ -15,7 +15,7 @@ _DEF_SPEED_OF_SOUND = 343.0
 
 
 def as_tensor(
-    value: Tensor | Iterable[float] | float | int,
+    value: Tensor | Iterable[float] | Iterable[Iterable[float]] | float | int,
     *,
     device: Optional[torch.device | str] = None,
     dtype: Optional[torch.dtype] = None,
```
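The widened annotation documents nested position lists as first-class input; runtime behavior is unchanged. A quick sketch of the accepted shapes (output shapes assume the usual `torch.as_tensor` conversion rules):

```python
import torch

from torchrir.utils import as_tensor

as_tensor(343.0)                    # scalar
as_tensor([1.0, 2.0, 1.5])          # flat iterable -> shape (3,)
as_tensor(
    [[1.0, 2.0, 1.5], [2.0, 2.0, 1.5]], dtype=torch.float32
)                                   # nested iterable -> shape (2, 3)
```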
torchrir-0.2.0.dist-info/METADATA ADDED

````diff
@@ -0,0 +1,70 @@
+Metadata-Version: 2.4
+Name: torchrir
+Version: 0.2.0
+Summary: PyTorch-based room impulse response (RIR) simulation toolkit for static and dynamic scenes.
+Project-URL: Repository, https://github.com/taishi-n/torchrir
+Requires-Python: >=3.10
+Description-Content-Type: text/markdown
+License-File: LICENSE
+License-File: NOTICE
+Requires-Dist: numpy>=2.2.6
+Requires-Dist: torch>=2.10.0
+Dynamic: license-file
+
+# TorchRIR
+
+PyTorch-based room impulse response (RIR) simulation toolkit focused on a clean, modern API with GPU support.
+This project has been substantially assisted by AI using Codex.
+
+## Installation
+```bash
+pip install torchrir
+```
+
+## Examples
+- `examples/static.py`: fixed sources/mics with binaural output.
+  `uv run python examples/static.py --plot`
+- `examples/dynamic_src.py`: moving sources, fixed mics.
+  `uv run python examples/dynamic_src.py --plot`
+- `examples/dynamic_mic.py`: fixed sources, moving mics.
+  `uv run python examples/dynamic_mic.py --plot`
+- `examples/cli.py`: unified CLI for static/dynamic scenes, JSON/YAML configs.
+  `uv run python examples/cli.py --mode static --plot`
+- `examples/cmu_arctic_dynamic_dataset.py`: small dynamic dataset generator (fixed room/mics, randomized source motion).
+  `uv run python examples/cmu_arctic_dynamic_dataset.py --num-scenes 4 --num-sources 2`
+- `examples/benchmark_device.py`: CPU/GPU benchmark for RIR simulation.
+  `uv run python examples/benchmark_device.py --dynamic`
+
+## Core API Overview
+- Geometry: `Room`, `Source`, `MicrophoneArray`
+- Static RIR: `simulate_rir`
+- Dynamic RIR: `simulate_dynamic_rir`
+- Dynamic convolution: `DynamicConvolver`
+- Metadata export: `build_metadata`, `save_metadata_json`
+
+```python
+from torchrir import DynamicConvolver, MicrophoneArray, Room, Source, simulate_rir
+
+room = Room.shoebox(size=[6.0, 4.0, 3.0], fs=16000, beta=[0.9] * 6)
+sources = Source.from_positions([[1.0, 2.0, 1.5]])
+mics = MicrophoneArray.from_positions([[2.0, 2.0, 1.5]])
+
+rir = simulate_rir(room=room, sources=sources, mics=mics, max_order=6, tmax=0.3)
+# For dynamic scenes, compute rirs with simulate_dynamic_rir and convolve:
+# y = DynamicConvolver(mode="trajectory").convolve(signal, rirs)
+```
+
+For detailed documentation, see the docs under `docs/` and Read the Docs.
+
+## Future Work
+- Ray tracing backend: implement `RayTracingSimulator` with frequency-dependent absorption/scattering.
+- CUDA-native acceleration: introduce dedicated CUDA kernels for large-scale RIR generation.
+- Dataset expansion: add additional dataset integrations beyond CMU ARCTIC (see `TemplateDataset`), including torchaudio datasets (e.g., LibriSpeech, VCTK, LibriTTS, SpeechCommands, CommonVoice, GTZAN, MUSDB-HQ).
+- Add regression tests comparing generated RIRs against gpuRIR outputs.
+
+## Related Libraries
+- [gpuRIR](https://github.com/DavidDiazGuerra/gpuRIR)
+- [Cross3D](https://github.com/DavidDiazGuerra/Cross3D)
+- [pyroomacoustics](https://github.com/LCAV/pyroomacoustics)
+- [das-generator](https://github.com/ehabets/das-generator)
+- [rir-generator](https://github.com/audiolabs/rir-generator)
````
torchrir-0.2.0.dist-info/RECORD ADDED

```diff
@@ -0,0 +1,30 @@
+torchrir/__init__.py,sha256=MTlouAErvB7IyM4pcmhEN1U0KQsZBglWxUoHxnZDa5U,2615
+torchrir/animation.py,sha256=x3Y-BLz3J6DQNmoDIjbMEgGfng2yavJFLyQEmRCSpQU,6391
+torchrir/config.py,sha256=PsZdDIS3p4jepeNSHyd69aSD9QlOEdpG9v1SAXlZ_Fg,2295
+torchrir/core.py,sha256=Ug5thts1rXvCpdq9twVHz72oWygmk5J6gliloozHKL4,31704
+torchrir/directivity.py,sha256=v_t37YgeXF_IYzbnrk0TCs1npb_0yKR7zHiG8XV3V4w,1259
+torchrir/dynamic.py,sha256=01JHMxhORdcz93J-YaMIeSLo7k2tHrZke8llPHHXwZg,2153
+torchrir/logging_utils.py,sha256=s4jDSSDoHT0HKeplDUpGMsdeBij4eibLSpaaAPzkB68,2146
+torchrir/metadata.py,sha256=cwoXrr_yE2bQRUPnJe6p7POMPCWa9_oabCtp--WqBE8,6958
+torchrir/plotting.py,sha256=TM1LxPitZq5KXdNe1GfUCOnzFOzerGhWIFblzIz142A,8170
+torchrir/plotting_utils.py,sha256=Kg3TCLqEq_lxVQkYHI6vyz_6oG3Ic_Z8H9gZN-39QeI,5180
+torchrir/results.py,sha256=-HczEfr2u91BNb1xbrIGKCj0G3yzy7l_fmUMUeKbGRw,614
+torchrir/room.py,sha256=zFnEzw0Rr1NP9IUc3iNTInyoq6t3X-0yOyUtDnsLSPk,4325
+torchrir/scene.py,sha256=GuHuCspakAUOT81_ArTqaZbmBX0ApoJuCKTaZ21wGis,2435
+torchrir/scene_utils.py,sha256=2La5dtjxYdINX315VXRRJMJK9oaR2rY0xHmDLjZma8M,2140
+torchrir/signal.py,sha256=M0BpKDBqrfOmCHIJ_dvl-C3uKdFpXLDqtSIU115jsME,8383
+torchrir/simulators.py,sha256=NCl8Ptv2TGdBpNLwAb3nigT77On-BLIANtc2ivgKasw,3131
+torchrir/utils.py,sha256=2oE-JzAtkW5qdRds2Y5R5lbSyNZl_9piFXd6xOLzjxM,10680
+torchrir/datasets/__init__.py,sha256=NS4zQas9YdsuDv8KQtTKmIJmS6mxMRxQk2xGzglbgUw,853
+torchrir/datasets/base.py,sha256=LfdXO-NGCBtzaAqAeVxo8XuV5ieU6Vl91woqAHymsT8,1970
+torchrir/datasets/cmu_arctic.py,sha256=7IFv33RBBu044kTMO6nKUmziml2gjILUgnpL262rAU8,6593
+torchrir/datasets/collate.py,sha256=gZfaHog0gtb8Avg6qsDZ1m4yoKkYkcuwmty1RtLYhhI,2542
+torchrir/datasets/librispeech.py,sha256=XKlAm0Z0coipKuqR9Z8X8l9puXVYz7zb6yE3PCuMUrI,6019
+torchrir/datasets/template.py,sha256=pHAKj5E7Gehfk9pqdTsFQjiDV1OK3hSZJIbYutd-E4c,2090
+torchrir/datasets/utils.py,sha256=OCYd7Dbr2hsqBbiHE1LHPMYdqwe2YfDw0tpRMfND0Og,3790
+torchrir-0.2.0.dist-info/licenses/LICENSE,sha256=5vS_7WTsMEw_QQHEPQ_WCwovJXEgmxoEwcwOI-9VbXI,10766
+torchrir-0.2.0.dist-info/licenses/NOTICE,sha256=SRs_q-ZqoVF9_YuuedZOvVBk01jV7YQAeF8rRvlRg0s,118
+torchrir-0.2.0.dist-info/METADATA,sha256=0f5EhK2SWmP_RxZ2O8y7AIDP_dFaKaf9X_n0mt7045k,3077
+torchrir-0.2.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+torchrir-0.2.0.dist-info/top_level.txt,sha256=aIFwntowJjvm7rZk480HymC3ipDo1g-9hEbNY1wF-Oo,9
+torchrir-0.2.0.dist-info/RECORD,,
```