torchrir 0.1.2__tar.gz → 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchrir-0.2.0/PKG-INFO +70 -0
- torchrir-0.2.0/README.md +57 -0
- {torchrir-0.1.2 → torchrir-0.2.0}/pyproject.toml +12 -1
- {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/__init__.py +16 -1
- {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/animation.py +17 -14
- {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/core.py +176 -35
- {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/datasets/__init__.py +9 -3
- torchrir-0.2.0/src/torchrir/datasets/base.py +67 -0
- {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/datasets/cmu_arctic.py +9 -20
- torchrir-0.2.0/src/torchrir/datasets/collate.py +90 -0
- torchrir-0.2.0/src/torchrir/datasets/librispeech.py +175 -0
- {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/datasets/template.py +3 -1
- {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/datasets/utils.py +23 -1
- {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/dynamic.py +3 -1
- {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/plotting.py +13 -6
- {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/plotting_utils.py +4 -1
- {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/room.py +2 -38
- {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/scene_utils.py +6 -2
- {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/signal.py +24 -10
- {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/simulators.py +12 -4
- {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/utils.py +1 -1
- torchrir-0.2.0/src/torchrir.egg-info/PKG-INFO +70 -0
- {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir.egg-info/SOURCES.txt +2 -0
- {torchrir-0.1.2 → torchrir-0.2.0}/tests/test_compare_pyroomacoustics.py +5 -3
- {torchrir-0.1.2 → torchrir-0.2.0}/tests/test_core.py +6 -6
- {torchrir-0.1.2 → torchrir-0.2.0}/tests/test_device_parity.py +4 -4
- {torchrir-0.1.2 → torchrir-0.2.0}/tests/test_plotting.py +1 -0
- {torchrir-0.1.2 → torchrir-0.2.0}/tests/test_scene.py +25 -13
- torchrir-0.1.2/PKG-INFO +0 -271
- torchrir-0.1.2/README.md +0 -258
- torchrir-0.1.2/src/torchrir/datasets/base.py +0 -27
- torchrir-0.1.2/src/torchrir.egg-info/PKG-INFO +0 -271
- {torchrir-0.1.2 → torchrir-0.2.0}/LICENSE +0 -0
- {torchrir-0.1.2 → torchrir-0.2.0}/NOTICE +0 -0
- {torchrir-0.1.2 → torchrir-0.2.0}/setup.cfg +0 -0
- {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/config.py +0 -0
- {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/directivity.py +0 -0
- {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/logging_utils.py +0 -0
- {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/metadata.py +0 -0
- {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/results.py +0 -0
- {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/scene.py +0 -0
- {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir.egg-info/dependency_links.txt +0 -0
- {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir.egg-info/requires.txt +0 -0
- {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir.egg-info/top_level.txt +0 -0
- {torchrir-0.1.2 → torchrir-0.2.0}/tests/test_room.py +0 -0
- {torchrir-0.1.2 → torchrir-0.2.0}/tests/test_signal.py +0 -0
- {torchrir-0.1.2 → torchrir-0.2.0}/tests/test_utils.py +0 -0
torchrir-0.2.0/PKG-INFO ADDED
@@ -0,0 +1,70 @@
+Metadata-Version: 2.4
+Name: torchrir
+Version: 0.2.0
+Summary: PyTorch-based room impulse response (RIR) simulation toolkit for static and dynamic scenes.
+Project-URL: Repository, https://github.com/taishi-n/torchrir
+Requires-Python: >=3.10
+Description-Content-Type: text/markdown
+License-File: LICENSE
+License-File: NOTICE
+Requires-Dist: numpy>=2.2.6
+Requires-Dist: torch>=2.10.0
+Dynamic: license-file
+
+# TorchRIR
+
+PyTorch-based room impulse response (RIR) simulation toolkit focused on a clean, modern API with GPU support.
+This project has been substantially assisted by AI using Codex.
+
+## Installation
+```bash
+pip install torchrir
+```
+
+## Examples
+- `examples/static.py`: fixed sources/mics with binaural output.
+  `uv run python examples/static.py --plot`
+- `examples/dynamic_src.py`: moving sources, fixed mics.
+  `uv run python examples/dynamic_src.py --plot`
+- `examples/dynamic_mic.py`: fixed sources, moving mics.
+  `uv run python examples/dynamic_mic.py --plot`
+- `examples/cli.py`: unified CLI for static/dynamic scenes, JSON/YAML configs.
+  `uv run python examples/cli.py --mode static --plot`
+- `examples/cmu_arctic_dynamic_dataset.py`: small dynamic dataset generator (fixed room/mics, randomized source motion).
+  `uv run python examples/cmu_arctic_dynamic_dataset.py --num-scenes 4 --num-sources 2`
+- `examples/benchmark_device.py`: CPU/GPU benchmark for RIR simulation.
+  `uv run python examples/benchmark_device.py --dynamic`
+
+## Core API Overview
+- Geometry: `Room`, `Source`, `MicrophoneArray`
+- Static RIR: `simulate_rir`
+- Dynamic RIR: `simulate_dynamic_rir`
+- Dynamic convolution: `DynamicConvolver`
+- Metadata export: `build_metadata`, `save_metadata_json`
+
+```python
+from torchrir import DynamicConvolver, MicrophoneArray, Room, Source, simulate_rir
+
+room = Room.shoebox(size=[6.0, 4.0, 3.0], fs=16000, beta=[0.9] * 6)
+sources = Source.from_positions([[1.0, 2.0, 1.5]])
+mics = MicrophoneArray.from_positions([[2.0, 2.0, 1.5]])
+
+rir = simulate_rir(room=room, sources=sources, mics=mics, max_order=6, tmax=0.3)
+# For dynamic scenes, compute rirs with simulate_dynamic_rir and convolve:
+# y = DynamicConvolver(mode="trajectory").convolve(signal, rirs)
+```
+
+For detailed documentation, see the docs under `docs/` and Read the Docs.
+
+## Future Work
+- Ray tracing backend: implement `RayTracingSimulator` with frequency-dependent absorption/scattering.
+- CUDA-native acceleration: introduce dedicated CUDA kernels for large-scale RIR generation.
+- Dataset expansion: add additional dataset integrations beyond CMU ARCTIC (see `TemplateDataset`), including torchaudio datasets (e.g., LibriSpeech, VCTK, LibriTTS, SpeechCommands, CommonVoice, GTZAN, MUSDB-HQ).
+- Add regression tests comparing generated RIRs against gpuRIR outputs.
+
+## Related Libraries
+- [gpuRIR](https://github.com/DavidDiazGuerra/gpuRIR)
+- [Cross3D](https://github.com/DavidDiazGuerra/Cross3D)
+- [pyroomacoustics](https://github.com/LCAV/pyroomacoustics)
+- [das-generator](https://github.com/ehabets/das-generator)
+- [rir-generator](https://github.com/audiolabs/rir-generator)
torchrir-0.2.0/README.md ADDED
@@ -0,0 +1,57 @@
+# TorchRIR
+
+PyTorch-based room impulse response (RIR) simulation toolkit focused on a clean, modern API with GPU support.
+This project has been substantially assisted by AI using Codex.
+
+## Installation
+```bash
+pip install torchrir
+```
+
+## Examples
+- `examples/static.py`: fixed sources/mics with binaural output.
+  `uv run python examples/static.py --plot`
+- `examples/dynamic_src.py`: moving sources, fixed mics.
+  `uv run python examples/dynamic_src.py --plot`
+- `examples/dynamic_mic.py`: fixed sources, moving mics.
+  `uv run python examples/dynamic_mic.py --plot`
+- `examples/cli.py`: unified CLI for static/dynamic scenes, JSON/YAML configs.
+  `uv run python examples/cli.py --mode static --plot`
+- `examples/cmu_arctic_dynamic_dataset.py`: small dynamic dataset generator (fixed room/mics, randomized source motion).
+  `uv run python examples/cmu_arctic_dynamic_dataset.py --num-scenes 4 --num-sources 2`
+- `examples/benchmark_device.py`: CPU/GPU benchmark for RIR simulation.
+  `uv run python examples/benchmark_device.py --dynamic`
+
+## Core API Overview
+- Geometry: `Room`, `Source`, `MicrophoneArray`
+- Static RIR: `simulate_rir`
+- Dynamic RIR: `simulate_dynamic_rir`
+- Dynamic convolution: `DynamicConvolver`
+- Metadata export: `build_metadata`, `save_metadata_json`
+
+```python
+from torchrir import DynamicConvolver, MicrophoneArray, Room, Source, simulate_rir
+
+room = Room.shoebox(size=[6.0, 4.0, 3.0], fs=16000, beta=[0.9] * 6)
+sources = Source.from_positions([[1.0, 2.0, 1.5]])
+mics = MicrophoneArray.from_positions([[2.0, 2.0, 1.5]])
+
+rir = simulate_rir(room=room, sources=sources, mics=mics, max_order=6, tmax=0.3)
+# For dynamic scenes, compute rirs with simulate_dynamic_rir and convolve:
+# y = DynamicConvolver(mode="trajectory").convolve(signal, rirs)
+```
+
+For detailed documentation, see the docs under `docs/` and Read the Docs.
+
+## Future Work
+- Ray tracing backend: implement `RayTracingSimulator` with frequency-dependent absorption/scattering.
+- CUDA-native acceleration: introduce dedicated CUDA kernels for large-scale RIR generation.
+- Dataset expansion: add additional dataset integrations beyond CMU ARCTIC (see `TemplateDataset`), including torchaudio datasets (e.g., LibriSpeech, VCTK, LibriTTS, SpeechCommands, CommonVoice, GTZAN, MUSDB-HQ).
+- Add regression tests comparing generated RIRs against gpuRIR outputs.
+
+## Related Libraries
+- [gpuRIR](https://github.com/DavidDiazGuerra/gpuRIR)
+- [Cross3D](https://github.com/DavidDiazGuerra/Cross3D)
+- [pyroomacoustics](https://github.com/LCAV/pyroomacoustics)
+- [das-generator](https://github.com/ehabets/das-generator)
+- [rir-generator](https://github.com/audiolabs/rir-generator)
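Note on the dynamic-scene comment in the README above: a minimal end-to-end sketch, assuming `simulate_dynamic_rir` accepts trajectories shaped `(time_steps, n_entities, 3)` and returns RIRs shaped `(time_steps, n_src, n_mic, nsample)` (both inferred from the core.py hunks below), and that `DynamicConvolver.convolve(signal, rirs)` behaves as the README comment suggests; the trajectory values and noise signal are illustrative.

```python
import torch

from torchrir import DynamicConvolver, Room, simulate_dynamic_rir

room = Room.shoebox(size=[6.0, 4.0, 3.0], fs=16000, beta=[0.9] * 6)

# One source sliding along x; one fixed microphone. Trajectories are
# (time_steps, n_entities, 3), matching the shape checks in simulate_dynamic_rir.
t_steps = 50
src_traj = torch.zeros(t_steps, 1, 3)
src_traj[:, 0, 0] = torch.linspace(1.0, 5.0, t_steps)
src_traj[:, 0, 1] = 2.0
src_traj[:, 0, 2] = 1.5
mic_traj = torch.tensor([2.0, 2.0, 1.5]).expand(t_steps, 1, 3)

# rirs: (t_steps, n_src, n_mic, nsample) per the new core.py implementation.
rirs = simulate_dynamic_rir(
    room=room, src_traj=src_traj, mic_traj=mic_traj, max_order=6, tmax=0.3
)

signal = torch.randn(room.fs)  # 1 s of noise as a stand-in source signal
y = DynamicConvolver(mode="trajectory").convolve(signal, rirs)
```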
{torchrir-0.1.2 → torchrir-0.2.0}/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "torchrir"
-version = "0.1.2"
+version = "0.2.0"
 description = "PyTorch-based room impulse response (RIR) simulation toolkit for static and dynamic scenes."
 readme = "README.md"
 requires-python = ">=3.10"
@@ -14,8 +14,10 @@ Repository = "https://github.com/taishi-n/torchrir"
 
 [dependency-groups]
 dev = [
+    "commitizen>=3.29.0",
     "git-cliff>=2.10.1",
     "matplotlib>=3.10.8",
+    "ruff>=0.12.2",
     "pillow>=11.2.1",
     "pyroomacoustics>=0.9.0",
     "pytest>=9.0.2",
@@ -23,4 +25,13 @@ dev = [
     "sphinx>=7.0,<8.2.3",
     "sphinx-rtd-theme>=2.0.0",
     "myst-parser>=2.0,<4.0",
+    "ty>=0.0.14,<0.1",
 ]
+
+[tool.commitizen]
+name = "cz_conventional_commits"
+tag_format = "v$version"
+version_scheme = "pep440"
+version_provider = "pep621"
+update_changelog_on_bump = true
+changelog_file = "CHANGELOG.md"
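Context for the new `[tool.commitizen]` table: `version_provider = "pep621"` makes commitizen read and bump the version in `[project]`, `tag_format = "v$version"` yields tags like `v0.2.0`, and `update_changelog_on_bump = true` regenerates `CHANGELOG.md` on each bump. Assuming commitizen's standard CLI, a release would be cut by running `cz bump` (presumably `uv run cz bump` here, matching the `uv run` convention used elsewhere in this package) after a series of conventional commits such as `feat(datasets): add LibriSpeech dataset`.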
|
@@ -18,6 +18,11 @@ from .datasets import (
|
|
|
18
18
|
CmuArcticDataset,
|
|
19
19
|
CmuArcticSentence,
|
|
20
20
|
choose_speakers,
|
|
21
|
+
CollateBatch,
|
|
22
|
+
collate_dataset_items,
|
|
23
|
+
DatasetItem,
|
|
24
|
+
LibriSpeechDataset,
|
|
25
|
+
LibriSpeechSentence,
|
|
21
26
|
list_cmu_arctic_speakers,
|
|
22
27
|
SentenceLike,
|
|
23
28
|
load_dataset_sources,
|
|
@@ -26,7 +31,12 @@ from .datasets import (
|
|
|
26
31
|
load_wav_mono,
|
|
27
32
|
save_wav,
|
|
28
33
|
)
|
|
29
|
-
from .scene_utils import
|
|
34
|
+
from .scene_utils import (
|
|
35
|
+
binaural_mic_positions,
|
|
36
|
+
clamp_positions,
|
|
37
|
+
linear_trajectory,
|
|
38
|
+
sample_positions,
|
|
39
|
+
)
|
|
30
40
|
from .utils import (
|
|
31
41
|
att2t_SabineEstimation,
|
|
32
42
|
att2t_sabine_estimation,
|
|
@@ -56,6 +66,11 @@ __all__ = [
|
|
|
56
66
|
"CmuArcticDataset",
|
|
57
67
|
"CmuArcticSentence",
|
|
58
68
|
"choose_speakers",
|
|
69
|
+
"CollateBatch",
|
|
70
|
+
"collate_dataset_items",
|
|
71
|
+
"DatasetItem",
|
|
72
|
+
"LibriSpeechDataset",
|
|
73
|
+
"LibriSpeechSentence",
|
|
59
74
|
"DynamicConvolver",
|
|
60
75
|
"estimate_beta_from_t60",
|
|
61
76
|
"estimate_t60_from_beta",
|
|
{torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/animation.py
@@ -104,15 +104,15 @@ def animate_scene_gif(
     mic_lines = []
     for _ in range(view_src_traj.shape[1]):
         if view_dim == 2:
-            line, = ax.plot([], [], color="tab:green", alpha=0.6)
+            (line,) = ax.plot([], [], color="tab:green", alpha=0.6)
         else:
-            line, = ax.plot([], [], [], color="tab:green", alpha=0.6)
+            (line,) = ax.plot([], [], [], color="tab:green", alpha=0.6)
         src_lines.append(line)
     for _ in range(view_mic_traj.shape[1]):
         if view_dim == 2:
-            line, = ax.plot([], [], color="tab:orange", alpha=0.6)
+            (line,) = ax.plot([], [], color="tab:orange", alpha=0.6)
         else:
-            line, = ax.plot([], [], [], color="tab:orange", alpha=0.6)
+            (line,) = ax.plot([], [], [], color="tab:orange", alpha=0.6)
         mic_lines.append(line)
 
     ax.legend(loc="best")
@@ -137,15 +137,15 @@ def animate_scene_gif(
                 xy = mic_frame[:, m_idx, :]
                 line.set_data(xy[:, 0], xy[:, 1])
         else:
-            src_scatter._offsets3d = (
-                src_pos_frame[:, 0],
-                src_pos_frame[:, 1],
-                src_pos_frame[:, 2],
+            setattr(
+                src_scatter,
+                "_offsets3d",
+                (src_pos_frame[:, 0], src_pos_frame[:, 1], src_pos_frame[:, 2]),
             )
-            mic_scatter._offsets3d = (
-                mic_pos_frame[:, 0],
-                mic_pos_frame[:, 1],
-                mic_pos_frame[:, 2],
+            setattr(
+                mic_scatter,
+                "_offsets3d",
+                (mic_pos_frame[:, 0], mic_pos_frame[:, 1], mic_pos_frame[:, 2]),
             )
             for s_idx, line in enumerate(src_lines):
                 xyz = src_frame[:, s_idx, :]
@@ -166,7 +166,10 @@ def animate_scene_gif(
         fps = frames / duration_s
     else:
        fps = 6.0
-    anim = animation.FuncAnimation(fig, _frame, frames=frames, interval=1000 / fps, blit=False)
-    anim.save(out_path, writer="pillow")
+    anim = animation.FuncAnimation(
+        fig, _frame, frames=frames, interval=1000 / fps, blit=False
+    )
+    fps_int = None if fps is None else max(1, int(round(fps)))
+    anim.save(out_path, writer="pillow", fps=fps_int)
     plt.close(fig)
     return out_path
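The last animation.py hunk exists because GIF writers expect an integer frame rate while a duration-derived `frames / duration_s` is generally fractional. A self-contained sketch of the same save pattern using matplotlib's public API (the toy sine animation is invented for illustration):

```python
import matplotlib

matplotlib.use("Agg")  # headless backend; we only write a file

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation

fig, ax = plt.subplots()
(line,) = ax.plot([], [])
ax.set_xlim(0.0, 1.0)
ax.set_ylim(-1.1, 1.1)

x = np.linspace(0.0, 1.0, 200)

def _frame(i: int):
    # Shift a sine wave a little each frame.
    line.set_data(x, np.sin(2.0 * np.pi * (x + i / 24.0)))
    return (line,)

frames = 24
fps = frames / 3.7  # duration-derived and fractional, like the fps in the diff
anim = animation.FuncAnimation(fig, _frame, frames=frames, interval=1000 / fps, blit=False)
# Round to an integer >= 1 before handing fps to the pillow writer.
anim.save("demo.gif", writer="pillow", fps=max(1, int(round(fps))))
plt.close(fig)
```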
{torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/core.py
@@ -3,6 +3,7 @@ from __future__ import annotations
 """Core RIR simulation functions (static and dynamic)."""
 
 import math
+from collections.abc import Callable
 from typing import Optional, Tuple
 
 import torch
@@ -61,8 +62,8 @@ def simulate_rir(
 
     Example:
         >>> room = Room.shoebox(size=[6.0, 4.0, 3.0], fs=16000, beta=[0.9] * 6)
-        >>> sources = Source.
-        >>> mics = MicrophoneArray.
+        >>> sources = Source.from_positions([[1.0, 2.0, 1.5]])
+        >>> mics = MicrophoneArray.from_positions([[2.0, 2.0, 1.5]])
         >>> rir = simulate_rir(
         ...     room=room,
         ...     sources=sources,
@@ -90,9 +91,9 @@ def simulate_rir(
 
     if not isinstance(room, Room):
         raise TypeError("room must be a Room instance")
-    if nsample is None and tmax is None:
-        raise ValueError("nsample or tmax must be provided")
     if nsample is None:
+        if tmax is None:
+            raise ValueError("nsample or tmax must be provided")
         nsample = int(math.ceil(tmax * room.fs))
     if nsample <= 0:
         raise ValueError("nsample must be positive")
@@ -261,6 +262,11 @@ def simulate_dynamic_rir(
 
     src_traj = as_tensor(src_traj, device=device, dtype=dtype)
     mic_traj = as_tensor(mic_traj, device=device, dtype=dtype)
+    device, dtype = infer_device_dtype(
+        src_traj, mic_traj, room.size, device=device, dtype=dtype
+    )
+    src_traj = as_tensor(src_traj, device=device, dtype=dtype)
+    mic_traj = as_tensor(mic_traj, device=device, dtype=dtype)
 
     if src_traj.ndim == 2:
         src_traj = src_traj.unsqueeze(1)
@@ -273,24 +279,95 @@
     if src_traj.shape[0] != mic_traj.shape[0]:
         raise ValueError("src_traj and mic_traj must have the same time length")
 
+    if not isinstance(room, Room):
+        raise TypeError("room must be a Room instance")
+    if nsample is None:
+        if tmax is None:
+            raise ValueError("nsample or tmax must be provided")
+        nsample = int(math.ceil(tmax * room.fs))
+    if nsample <= 0:
+        raise ValueError("nsample must be positive")
+    if max_order < 0:
+        raise ValueError("max_order must be non-negative")
+
+    room_size = as_tensor(room.size, device=device, dtype=dtype)
+    room_size = ensure_dim(room_size)
+    dim = room_size.numel()
+    if src_traj.shape[2] != dim:
+        raise ValueError("src_traj must match room dimension")
+    if mic_traj.shape[2] != dim:
+        raise ValueError("mic_traj must match room dimension")
+
+    src_ori = None
+    mic_ori = None
+    if orientation is not None:
+        if isinstance(orientation, (list, tuple)):
+            if len(orientation) != 2:
+                raise ValueError("orientation tuple must have length 2")
+            src_ori, mic_ori = orientation
+        else:
+            src_ori = orientation
+            mic_ori = orientation
+    if src_ori is not None:
+        src_ori = as_tensor(src_ori, device=device, dtype=dtype)
+    if mic_ori is not None:
+        mic_ori = as_tensor(mic_ori, device=device, dtype=dtype)
+
+    beta = _resolve_beta(room, room_size, device=device, dtype=dtype)
+    beta = _validate_beta(beta, dim)
+    n_vec = _image_source_indices(max_order, dim, device=device, nb_img=None)
+    refl = _reflection_coefficients(n_vec, beta)
+
+    src_pattern, mic_pattern = split_directivity(directivity)
+    mic_dir = None
+    if mic_pattern != "omni":
+        if mic_ori is None:
+            raise ValueError("mic orientation required for non-omni directivity")
+        mic_dir = orientation_to_unit(mic_ori, dim)
+
+    n_src = src_traj.shape[1]
+    n_mic = mic_traj.shape[1]
+    rirs = torch.zeros((src_traj.shape[0], n_src, n_mic, nsample), device=device, dtype=dtype)
+    fdl = cfg.frac_delay_length
+    fdl2 = (fdl - 1) // 2
+    img_chunk = cfg.image_chunk_size
+    if img_chunk <= 0:
+        img_chunk = n_vec.shape[0]
+
+    src_dirs = None
+    if src_pattern != "omni":
+        if src_ori is None:
+            raise ValueError("source orientation required for non-omni directivity")
+        src_dirs = orientation_to_unit(src_ori, dim)
+        if src_dirs.ndim == 1:
+            src_dirs = src_dirs.unsqueeze(0).repeat(n_src, 1)
+        if src_dirs.ndim != 2 or src_dirs.shape[0] != n_src:
+            raise ValueError("source orientation must match number of sources")
+
+    for start in range(0, n_vec.shape[0], img_chunk):
+        end = min(start + img_chunk, n_vec.shape[0])
+        n_vec_chunk = n_vec[start:end]
+        refl_chunk = refl[start:end]
+        sample_chunk, attenuation_chunk = _compute_image_contributions_time_batch(
+            src_traj,
+            mic_traj,
+            room_size,
+            n_vec_chunk,
+            refl_chunk,
+            room,
+            fdl2,
+            src_pattern=src_pattern,
+            mic_pattern=mic_pattern,
+            src_dirs=src_dirs,
+            mic_dir=mic_dir,
         )
+        t_steps = src_traj.shape[0]
+        sample_flat = sample_chunk.reshape(t_steps * n_src, n_mic, -1)
+        attenuation_flat = attenuation_chunk.reshape(t_steps * n_src, n_mic, -1)
+        rir_flat = rirs.view(t_steps * n_src, n_mic, nsample)
+        _accumulate_rir_batch(rir_flat, sample_flat, attenuation_flat, cfg)
+
+    return rirs
 
 
 def _prepare_entities(
@@ -495,7 +572,11 @@ def _compute_image_contributions_batch(
     if mic_pattern != "omni":
         if mic_dir is None:
             raise ValueError("mic orientation required for non-omni directivity")
-        mic_dir = mic_dir[None, :, None, :] if mic_dir.ndim == 2 else mic_dir.view(1, 1, 1, -1)
+        mic_dir = (
+            mic_dir[None, :, None, :]
+            if mic_dir.ndim == 2
+            else mic_dir.view(1, 1, 1, -1)
+        )
         cos_theta = _cos_between(-vec, mic_dir)
         gain = gain * directivity_gain(mic_pattern, cos_theta)
 
@@ -503,6 +584,54 @@ def _compute_image_contributions_batch(
     return sample, attenuation
 
 
+def _compute_image_contributions_time_batch(
+    src_traj: Tensor,
+    mic_traj: Tensor,
+    room_size: Tensor,
+    n_vec: Tensor,
+    refl: Tensor,
+    room: Room,
+    fdl2: int,
+    *,
+    src_pattern: str,
+    mic_pattern: str,
+    src_dirs: Optional[Tensor],
+    mic_dir: Optional[Tensor],
+) -> Tuple[Tensor, Tensor]:
+    """Compute samples/attenuation for all time steps in batch."""
+    sign = torch.where((n_vec % 2) == 0, 1.0, -1.0).to(dtype=src_traj.dtype)
+    n = torch.floor_divide(n_vec + 1, 2).to(dtype=src_traj.dtype)
+    base = 2.0 * room_size * n
+    img = base[None, None, :, :] + sign[None, None, :, :] * src_traj[:, :, None, :]
+    vec = mic_traj[:, None, :, None, :] - img[:, :, None, :, :]
+    dist = torch.linalg.norm(vec, dim=-1)
+    dist = torch.clamp(dist, min=1e-6)
+    time = dist / room.c
+    time = time + (fdl2 / room.fs)
+    sample = time * room.fs
+
+    gain = refl.view(1, 1, 1, -1)
+    if src_pattern != "omni":
+        if src_dirs is None:
+            raise ValueError("source orientation required for non-omni directivity")
+        src_dirs_b = src_dirs[None, :, None, None, :]
+        cos_theta = _cos_between(vec, src_dirs_b)
+        gain = gain * directivity_gain(src_pattern, cos_theta)
+    if mic_pattern != "omni":
+        if mic_dir is None:
+            raise ValueError("mic orientation required for non-omni directivity")
+        mic_dir_b = (
+            mic_dir[None, None, :, None, :]
+            if mic_dir.ndim == 2
+            else mic_dir.view(1, 1, 1, 1, -1)
+        )
+        cos_theta = _cos_between(-vec, mic_dir_b)
+        gain = gain * directivity_gain(mic_pattern, cos_theta)
+
+    attenuation = gain / dist
+    return sample, attenuation
+
+
 def _select_orientation(orientation: Tensor, idx: int, count: int, dim: int) -> Tensor:
     """Pick the correct orientation vector for a given entity index."""
     if orientation.ndim == 0:
@@ -542,9 +671,9 @@ def _accumulate_rir(
     if use_lut:
         sinc_lut = _get_sinc_lut(fdl, lut_gran, device=rir.device, dtype=dtype)
 
-    mic_offsets = (torch.arange(n_mic, device=rir.device, dtype=torch.int64) * nsample).view(
-        n_mic, 1, 1
-    )
+    mic_offsets = (
+        torch.arange(n_mic, device=rir.device, dtype=torch.int64) * nsample
+    ).view(n_mic, 1, 1)
     rir_flat = rir.view(-1)
 
     chunk_size = cfg.accumulate_chunk_size
@@ -559,7 +688,9 @@ def _accumulate_rir(
         x_off_frac = (1.0 - frac_m) * lut_gran
         lut_gran_off = torch.floor(x_off_frac).to(torch.int64)
         x_off = x_off_frac - lut_gran_off.to(dtype)
-        lut_pos = lut_gran_off[..., None] + (n[None, None, :].to(torch.int64) * lut_gran)
+        lut_pos = lut_gran_off[..., None] + (
+            n[None, None, :].to(torch.int64) * lut_gran
+        )
 
         s0 = torch.take(sinc_lut, lut_pos)
         s1 = torch.take(sinc_lut, lut_pos + 1)
@@ -618,9 +749,9 @@ def _accumulate_rir_batch_impl(
     if use_lut:
         sinc_lut = _get_sinc_lut(fdl, lut_gran, device=rir.device, dtype=sample.dtype)
 
-    sm_offsets = (torch.arange(n_sm, device=rir.device, dtype=torch.int64) * nsample).view(
-        n_sm, 1, 1
-    )
+    sm_offsets = (
+        torch.arange(n_sm, device=rir.device, dtype=torch.int64) * nsample
+    ).view(n_sm, 1, 1)
     rir_flat = rir.view(-1)
 
     n_img = idx0.shape[1]
@@ -634,7 +765,9 @@ def _accumulate_rir_batch_impl(
         x_off_frac = (1.0 - frac_m) * lut_gran
         lut_gran_off = torch.floor(x_off_frac).to(torch.int64)
         x_off = x_off_frac - lut_gran_off.to(sample.dtype)
-        lut_pos = lut_gran_off[..., None] + (n[None, None, :].to(torch.int64) * lut_gran)
+        lut_pos = lut_gran_off[..., None] + (
+            n[None, None, :].to(torch.int64) * lut_gran
+        )
 
         s0 = torch.take(sinc_lut, lut_pos)
         s1 = torch.take(sinc_lut, lut_pos + 1)
@@ -660,12 +793,13 @@ _SINC_LUT_CACHE: dict[tuple[int, int, str, torch.dtype], Tensor] = {}
 _FDL_GRID_CACHE: dict[tuple[int, str, torch.dtype], Tensor] = {}
 _FDL_OFFSETS_CACHE: dict[tuple[int, str], Tensor] = {}
 _FDL_WINDOW_CACHE: dict[tuple[int, str, torch.dtype], Tensor] = {}
+_AccumFn = Callable[[Tensor, Tensor, Tensor], None]
+_ACCUM_BATCH_COMPILED: dict[tuple[str, torch.dtype, int, int, bool, int], _AccumFn] = {}
 
 
 def _get_accumulate_fn(
     cfg: SimulationConfig, device: torch.device, dtype: torch.dtype
-) ->
+) -> _AccumFn:
     """Return an accumulation function with config-bound constants."""
     use_lut = cfg.use_lut and device.type != "mps"
     fdl = cfg.frac_delay_length
@@ -721,7 +855,9 @@ def _get_fdl_window(fdl: int, *, device: torch.device, dtype: torch.dtype) -> Te
     return cached
 
 
-def _get_sinc_lut(fdl: int, lut_gran: int, *, device: torch.device, dtype: torch.dtype) -> Tensor:
+def _get_sinc_lut(
+    fdl: int, lut_gran: int, *, device: torch.device, dtype: torch.dtype
+) -> Tensor:
     """Create a sinc lookup table for fractional delays."""
     key = (fdl, lut_gran, str(device), dtype)
     cached = _SINC_LUT_CACHE.get(key)
@@ -765,7 +901,12 @@ def _apply_diffuse_tail(
 
     gen = torch.Generator(device=rir.device)
     gen.manual_seed(0 if seed is None else seed)
-    noise = torch.randn(rir[..., tdiff_idx:].shape, device=rir.device, dtype=rir.dtype, generator=gen)
-    scale = torch.linalg.norm(rir[..., tdiff_idx - 1 : tdiff_idx], dim=-1, keepdim=True) + 1e-8
+    noise = torch.randn(
+        rir[..., tdiff_idx:].shape, device=rir.device, dtype=rir.dtype, generator=gen
+    )
+    scale = (
+        torch.linalg.norm(rir[..., tdiff_idx - 1 : tdiff_idx], dim=-1, keepdim=True)
+        + 1e-8
+    )
     rir[..., tdiff_idx:] = noise * decay * scale
     return rir
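The new `_compute_image_contributions_time_batch` above is the image-source method vectorized over trajectory steps: for an image index vector `n_vec`, even entries keep the source coordinate's sign, odd entries mirror it, `n = floor((n_vec + 1) / 2)` counts room-size shifts, and the image lands at `img = 2 * room_size * n + sign * src`. A minimal sketch of that indexing for the direct path plus the two first-order x-wall images (values illustrative; reflection coefficients and fractional-delay handling omitted):

```python
import torch

room_size = torch.tensor([6.0, 4.0, 3.0])  # shoebox dimensions in meters
src = torch.tensor([1.0, 2.0, 1.5])        # source position
mic = torch.tensor([2.0, 2.0, 1.5])        # microphone position
c, fs = 343.0, 16000                       # speed of sound, sample rate

# Direct path plus the two first-order images across the x walls.
n_vec = torch.tensor([[0, 0, 0], [-1, 0, 0], [1, 0, 0]])
sign = torch.where((n_vec % 2) == 0, 1.0, -1.0)         # even -> keep, odd -> mirror
n = torch.floor_divide(n_vec + 1, 2).to(torch.float32)  # room-size shift count

img = 2.0 * room_size * n + sign * src       # image positions, shape (3, 3)
dist = torch.linalg.norm(mic - img, dim=-1)  # propagation path lengths
delay_samples = dist / c * fs                # arrival times in samples
gain = 1.0 / dist                            # 1/r spreading only
print(delay_samples, gain)
```

With these values the path lengths come out as 1 m, 3 m, and 9 m: the direct sound plus mirrors across the `x = 0` wall (image at `[-1, 2, 1.5]`) and the `x = 6` wall (image at `[11, 2, 1.5]`).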
{torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/datasets/__init__.py
@@ -1,14 +1,15 @@
 """Dataset helpers for torchrir."""
 
-from .base import BaseDataset, SentenceLike
-from .utils import choose_speakers, load_dataset_sources
+from .base import BaseDataset, DatasetItem, SentenceLike
+from .utils import choose_speakers, load_dataset_sources, load_wav_mono
+from .collate import CollateBatch, collate_dataset_items
 from .template import TemplateDataset, TemplateSentence
+from .librispeech import LibriSpeechDataset, LibriSpeechSentence
 
 from .cmu_arctic import (
     CmuArcticDataset,
     CmuArcticSentence,
     list_cmu_arctic_speakers,
-    load_wav_mono,
     save_wav,
 )
 
@@ -17,6 +18,9 @@ __all__ = [
     "CmuArcticDataset",
     "CmuArcticSentence",
     "choose_speakers",
+    "DatasetItem",
+    "CollateBatch",
+    "collate_dataset_items",
     "list_cmu_arctic_speakers",
     "SentenceLike",
     "load_dataset_sources",
@@ -24,4 +28,6 @@ __all__ = [
     "save_wav",
     "TemplateDataset",
     "TemplateSentence",
+    "LibriSpeechDataset",
+    "LibriSpeechSentence",
 ]
torchrir-0.2.0/src/torchrir/datasets/base.py ADDED
@@ -0,0 +1,67 @@
+from __future__ import annotations
+
+"""Dataset protocol definitions."""
+
+from dataclasses import dataclass
+from typing import Optional, Protocol, Sequence, Tuple
+
+import torch
+from torch.utils.data import Dataset
+
+
+class SentenceLike(Protocol):
+    """Minimal sentence interface for dataset entries."""
+
+    utterance_id: str
+    text: str
+
+
+@dataclass(frozen=True)
+class DatasetItem:
+    """Dataset item for DataLoader consumption."""
+
+    audio: torch.Tensor
+    sample_rate: int
+    utterance_id: str
+    text: Optional[str] = None
+    speaker: Optional[str] = None
+
+
+class BaseDataset(Dataset[DatasetItem]):
+    """Base dataset class compatible with torch.utils.data.Dataset."""
+
+    _sentences_cache: Optional[list[SentenceLike]] = None
+
+    def list_speakers(self) -> list[str]:
+        """Return available speaker IDs."""
+        raise NotImplementedError
+
+    def available_sentences(self) -> Sequence[SentenceLike]:
+        """Return sentence entries that have audio available."""
+        raise NotImplementedError
+
+    def load_wav(self, utterance_id: str) -> Tuple[torch.Tensor, int]:
+        """Load audio for an utterance and return (audio, sample_rate)."""
+        raise NotImplementedError
+
+    def __len__(self) -> int:
+        return len(self._get_sentences())
+
+    def __getitem__(self, idx: int) -> DatasetItem:
+        sentences = self._get_sentences()
+        sentence = sentences[idx]
+        audio, sample_rate = self.load_wav(sentence.utterance_id)
+        speaker = getattr(self, "speaker", None)
+        text = getattr(sentence, "text", None)
+        return DatasetItem(
+            audio=audio,
+            sample_rate=sample_rate,
+            utterance_id=sentence.utterance_id,
+            text=text,
+            speaker=speaker,
+        )
+
+    def _get_sentences(self) -> list[SentenceLike]:
+        if self._sentences_cache is None:
+            self._sentences_cache = list(self.available_sentences())
+        return self._sentences_cache