torchrir 0.1.2__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. torchrir-0.2.0/PKG-INFO +70 -0
  2. torchrir-0.2.0/README.md +57 -0
  3. {torchrir-0.1.2 → torchrir-0.2.0}/pyproject.toml +12 -1
  4. {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/__init__.py +16 -1
  5. {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/animation.py +17 -14
  6. {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/core.py +176 -35
  7. {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/datasets/__init__.py +9 -3
  8. torchrir-0.2.0/src/torchrir/datasets/base.py +67 -0
  9. {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/datasets/cmu_arctic.py +9 -20
  10. torchrir-0.2.0/src/torchrir/datasets/collate.py +90 -0
  11. torchrir-0.2.0/src/torchrir/datasets/librispeech.py +175 -0
  12. {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/datasets/template.py +3 -1
  13. {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/datasets/utils.py +23 -1
  14. {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/dynamic.py +3 -1
  15. {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/plotting.py +13 -6
  16. {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/plotting_utils.py +4 -1
  17. {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/room.py +2 -38
  18. {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/scene_utils.py +6 -2
  19. {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/signal.py +24 -10
  20. {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/simulators.py +12 -4
  21. {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/utils.py +1 -1
  22. torchrir-0.2.0/src/torchrir.egg-info/PKG-INFO +70 -0
  23. {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir.egg-info/SOURCES.txt +2 -0
  24. {torchrir-0.1.2 → torchrir-0.2.0}/tests/test_compare_pyroomacoustics.py +5 -3
  25. {torchrir-0.1.2 → torchrir-0.2.0}/tests/test_core.py +6 -6
  26. {torchrir-0.1.2 → torchrir-0.2.0}/tests/test_device_parity.py +4 -4
  27. {torchrir-0.1.2 → torchrir-0.2.0}/tests/test_plotting.py +1 -0
  28. {torchrir-0.1.2 → torchrir-0.2.0}/tests/test_scene.py +25 -13
  29. torchrir-0.1.2/PKG-INFO +0 -271
  30. torchrir-0.1.2/README.md +0 -258
  31. torchrir-0.1.2/src/torchrir/datasets/base.py +0 -27
  32. torchrir-0.1.2/src/torchrir.egg-info/PKG-INFO +0 -271
  33. {torchrir-0.1.2 → torchrir-0.2.0}/LICENSE +0 -0
  34. {torchrir-0.1.2 → torchrir-0.2.0}/NOTICE +0 -0
  35. {torchrir-0.1.2 → torchrir-0.2.0}/setup.cfg +0 -0
  36. {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/config.py +0 -0
  37. {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/directivity.py +0 -0
  38. {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/logging_utils.py +0 -0
  39. {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/metadata.py +0 -0
  40. {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/results.py +0 -0
  41. {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/scene.py +0 -0
  42. {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir.egg-info/dependency_links.txt +0 -0
  43. {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir.egg-info/requires.txt +0 -0
  44. {torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir.egg-info/top_level.txt +0 -0
  45. {torchrir-0.1.2 → torchrir-0.2.0}/tests/test_room.py +0 -0
  46. {torchrir-0.1.2 → torchrir-0.2.0}/tests/test_signal.py +0 -0
  47. {torchrir-0.1.2 → torchrir-0.2.0}/tests/test_utils.py +0 -0
torchrir-0.2.0/PKG-INFO
@@ -0,0 +1,70 @@
+ Metadata-Version: 2.4
+ Name: torchrir
+ Version: 0.2.0
+ Summary: PyTorch-based room impulse response (RIR) simulation toolkit for static and dynamic scenes.
+ Project-URL: Repository, https://github.com/taishi-n/torchrir
+ Requires-Python: >=3.10
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ License-File: NOTICE
+ Requires-Dist: numpy>=2.2.6
+ Requires-Dist: torch>=2.10.0
+ Dynamic: license-file
+
+ # TorchRIR
+
+ PyTorch-based room impulse response (RIR) simulation toolkit focused on a clean, modern API with GPU support.
+ This project has been substantially assisted by AI using Codex.
+
+ ## Installation
+ ```bash
+ pip install torchrir
+ ```
+
+ ## Examples
+ - `examples/static.py`: fixed sources/mics with binaural output.
+   `uv run python examples/static.py --plot`
+ - `examples/dynamic_src.py`: moving sources, fixed mics.
+   `uv run python examples/dynamic_src.py --plot`
+ - `examples/dynamic_mic.py`: fixed sources, moving mics.
+   `uv run python examples/dynamic_mic.py --plot`
+ - `examples/cli.py`: unified CLI for static/dynamic scenes, JSON/YAML configs.
+   `uv run python examples/cli.py --mode static --plot`
+ - `examples/cmu_arctic_dynamic_dataset.py`: small dynamic dataset generator (fixed room/mics, randomized source motion).
+   `uv run python examples/cmu_arctic_dynamic_dataset.py --num-scenes 4 --num-sources 2`
+ - `examples/benchmark_device.py`: CPU/GPU benchmark for RIR simulation.
+   `uv run python examples/benchmark_device.py --dynamic`
+
+ ## Core API Overview
+ - Geometry: `Room`, `Source`, `MicrophoneArray`
+ - Static RIR: `simulate_rir`
+ - Dynamic RIR: `simulate_dynamic_rir`
+ - Dynamic convolution: `DynamicConvolver`
+ - Metadata export: `build_metadata`, `save_metadata_json`
+
+ ```python
+ from torchrir import DynamicConvolver, MicrophoneArray, Room, Source, simulate_rir
+
+ room = Room.shoebox(size=[6.0, 4.0, 3.0], fs=16000, beta=[0.9] * 6)
+ sources = Source.from_positions([[1.0, 2.0, 1.5]])
+ mics = MicrophoneArray.from_positions([[2.0, 2.0, 1.5]])
+
+ rir = simulate_rir(room=room, sources=sources, mics=mics, max_order=6, tmax=0.3)
+ # For dynamic scenes, compute rirs with simulate_dynamic_rir and convolve:
+ # y = DynamicConvolver(mode="trajectory").convolve(signal, rirs)
+ ```
+
+ For detailed documentation, see the docs under `docs/` and Read the Docs.
+
+ ## Future Work
+ - Ray tracing backend: implement `RayTracingSimulator` with frequency-dependent absorption/scattering.
+ - CUDA-native acceleration: introduce dedicated CUDA kernels for large-scale RIR generation.
+ - Dataset expansion: add additional dataset integrations beyond CMU ARCTIC (see `TemplateDataset`), including torchaudio datasets (e.g., LibriSpeech, VCTK, LibriTTS, SpeechCommands, CommonVoice, GTZAN, MUSDB-HQ).
+ - Add regression tests comparing generated RIRs against gpuRIR outputs.
+
+ ## Related Libraries
+ - [gpuRIR](https://github.com/DavidDiazGuerra/gpuRIR)
+ - [Cross3D](https://github.com/DavidDiazGuerra/Cross3D)
+ - [pyroomacoustics](https://github.com/LCAV/pyroomacoustics)
+ - [das-generator](https://github.com/ehabets/das-generator)
+ - [rir-generator](https://github.com/audiolabs/rir-generator)
torchrir-0.2.0/README.md
@@ -0,0 +1,57 @@
+ # TorchRIR
+
+ PyTorch-based room impulse response (RIR) simulation toolkit focused on a clean, modern API with GPU support.
+ This project has been substantially assisted by AI using Codex.
+
+ ## Installation
+ ```bash
+ pip install torchrir
+ ```
+
+ ## Examples
+ - `examples/static.py`: fixed sources/mics with binaural output.
+   `uv run python examples/static.py --plot`
+ - `examples/dynamic_src.py`: moving sources, fixed mics.
+   `uv run python examples/dynamic_src.py --plot`
+ - `examples/dynamic_mic.py`: fixed sources, moving mics.
+   `uv run python examples/dynamic_mic.py --plot`
+ - `examples/cli.py`: unified CLI for static/dynamic scenes, JSON/YAML configs.
+   `uv run python examples/cli.py --mode static --plot`
+ - `examples/cmu_arctic_dynamic_dataset.py`: small dynamic dataset generator (fixed room/mics, randomized source motion).
+   `uv run python examples/cmu_arctic_dynamic_dataset.py --num-scenes 4 --num-sources 2`
+ - `examples/benchmark_device.py`: CPU/GPU benchmark for RIR simulation.
+   `uv run python examples/benchmark_device.py --dynamic`
+
+ ## Core API Overview
+ - Geometry: `Room`, `Source`, `MicrophoneArray`
+ - Static RIR: `simulate_rir`
+ - Dynamic RIR: `simulate_dynamic_rir`
+ - Dynamic convolution: `DynamicConvolver`
+ - Metadata export: `build_metadata`, `save_metadata_json`
+
+ ```python
+ from torchrir import DynamicConvolver, MicrophoneArray, Room, Source, simulate_rir
+
+ room = Room.shoebox(size=[6.0, 4.0, 3.0], fs=16000, beta=[0.9] * 6)
+ sources = Source.from_positions([[1.0, 2.0, 1.5]])
+ mics = MicrophoneArray.from_positions([[2.0, 2.0, 1.5]])
+
+ rir = simulate_rir(room=room, sources=sources, mics=mics, max_order=6, tmax=0.3)
+ # For dynamic scenes, compute rirs with simulate_dynamic_rir and convolve:
+ # y = DynamicConvolver(mode="trajectory").convolve(signal, rirs)
+ ```
+
+ For detailed documentation, see the docs under `docs/` and Read the Docs.
+
+ ## Future Work
+ - Ray tracing backend: implement `RayTracingSimulator` with frequency-dependent absorption/scattering.
+ - CUDA-native acceleration: introduce dedicated CUDA kernels for large-scale RIR generation.
+ - Dataset expansion: add additional dataset integrations beyond CMU ARCTIC (see `TemplateDataset`), including torchaudio datasets (e.g., LibriSpeech, VCTK, LibriTTS, SpeechCommands, CommonVoice, GTZAN, MUSDB-HQ).
+ - Add regression tests comparing generated RIRs against gpuRIR outputs.
+
+ ## Related Libraries
+ - [gpuRIR](https://github.com/DavidDiazGuerra/gpuRIR)
+ - [Cross3D](https://github.com/DavidDiazGuerra/Cross3D)
+ - [pyroomacoustics](https://github.com/LCAV/pyroomacoustics)
+ - [das-generator](https://github.com/ehabets/das-generator)
+ - [rir-generator](https://github.com/audiolabs/rir-generator)
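The README only sketches dynamic usage in two comments. Below is a minimal end-to-end sketch of how those pieces appear intended to fit together; the keyword names passed to `simulate_dynamic_rir` follow the local variables in the `core.py` changes further down and are assumptions, not confirmed API.

```python
import torch
from torchrir import DynamicConvolver, Room, simulate_dynamic_rir

room = Room.shoebox(size=[6.0, 4.0, 3.0], fs=16000, beta=[0.9] * 6)

T = 50  # number of trajectory steps
# One source moving along x and one fixed microphone, both shaped (T, n, 3).
xs = torch.linspace(1.0, 5.0, T)
src_traj = torch.stack([xs, torch.full((T,), 2.0), torch.full((T,), 1.5)], dim=-1).unsqueeze(1)
mic_traj = torch.tensor([[2.0, 2.0, 1.5]]).expand(T, 1, 3)

# Argument names src_traj / mic_traj are assumed from core.py below.
rirs = simulate_dynamic_rir(
    room=room, src_traj=src_traj, mic_traj=mic_traj, max_order=6, tmax=0.3
)  # expected shape: (T, n_src, n_mic, nsample)

signal = torch.randn(16000)  # stand-in for a speech utterance
# Convolution call mirrors the commented line in the README.
y = DynamicConvolver(mode="trajectory").convolve(signal, rirs)
```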
{torchrir-0.1.2 → torchrir-0.2.0}/pyproject.toml
@@ -1,6 +1,6 @@
  [project]
  name = "torchrir"
- version = "0.1.2"
+ version = "0.2.0"
  description = "PyTorch-based room impulse response (RIR) simulation toolkit for static and dynamic scenes."
  readme = "README.md"
  requires-python = ">=3.10"
@@ -14,8 +14,10 @@ Repository = "https://github.com/taishi-n/torchrir"

  [dependency-groups]
  dev = [
+     "commitizen>=3.29.0",
      "git-cliff>=2.10.1",
      "matplotlib>=3.10.8",
+     "ruff>=0.12.2",
      "pillow>=11.2.1",
      "pyroomacoustics>=0.9.0",
      "pytest>=9.0.2",
@@ -23,4 +25,13 @@ dev = [
      "sphinx>=7.0,<8.2.3",
      "sphinx-rtd-theme>=2.0.0",
      "myst-parser>=2.0,<4.0",
+     "ty>=0.0.14,<0.1",
  ]
+
+ [tool.commitizen]
+ name = "cz_conventional_commits"
+ tag_format = "v$version"
+ version_scheme = "pep440"
+ version_provider = "pep621"
+ update_changelog_on_bump = true
+ changelog_file = "CHANGELOG.md"
{torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/__init__.py
@@ -18,6 +18,11 @@ from .datasets import (
      CmuArcticDataset,
      CmuArcticSentence,
      choose_speakers,
+     CollateBatch,
+     collate_dataset_items,
+     DatasetItem,
+     LibriSpeechDataset,
+     LibriSpeechSentence,
      list_cmu_arctic_speakers,
      SentenceLike,
      load_dataset_sources,
@@ -26,7 +31,12 @@ from .datasets import (
      load_wav_mono,
      save_wav,
  )
- from .scene_utils import binaural_mic_positions, clamp_positions, linear_trajectory, sample_positions
+ from .scene_utils import (
+     binaural_mic_positions,
+     clamp_positions,
+     linear_trajectory,
+     sample_positions,
+ )
  from .utils import (
      att2t_SabineEstimation,
      att2t_sabine_estimation,
@@ -56,6 +66,11 @@ __all__ = [
      "CmuArcticDataset",
      "CmuArcticSentence",
      "choose_speakers",
+     "CollateBatch",
+     "collate_dataset_items",
+     "DatasetItem",
+     "LibriSpeechDataset",
+     "LibriSpeechSentence",
      "DynamicConvolver",
      "estimate_beta_from_t60",
      "estimate_t60_from_beta",
{torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/animation.py
@@ -104,15 +104,15 @@ def animate_scene_gif(
      mic_lines = []
      for _ in range(view_src_traj.shape[1]):
          if view_dim == 2:
-             line, = ax.plot([], [], color="tab:green", alpha=0.6)
+             (line,) = ax.plot([], [], color="tab:green", alpha=0.6)
          else:
-             line, = ax.plot([], [], [], color="tab:green", alpha=0.6)
+             (line,) = ax.plot([], [], [], color="tab:green", alpha=0.6)
          src_lines.append(line)
      for _ in range(view_mic_traj.shape[1]):
          if view_dim == 2:
-             line, = ax.plot([], [], color="tab:orange", alpha=0.6)
+             (line,) = ax.plot([], [], color="tab:orange", alpha=0.6)
          else:
-             line, = ax.plot([], [], [], color="tab:orange", alpha=0.6)
+             (line,) = ax.plot([], [], [], color="tab:orange", alpha=0.6)
          mic_lines.append(line)

      ax.legend(loc="best")
@@ -137,15 +137,15 @@ def animate_scene_gif(
                  xy = mic_frame[:, m_idx, :]
                  line.set_data(xy[:, 0], xy[:, 1])
          else:
-             src_scatter._offsets3d = (
-                 src_pos_frame[:, 0],
-                 src_pos_frame[:, 1],
-                 src_pos_frame[:, 2],
+             setattr(
+                 src_scatter,
+                 "_offsets3d",
+                 (src_pos_frame[:, 0], src_pos_frame[:, 1], src_pos_frame[:, 2]),
              )
-             mic_scatter._offsets3d = (
-                 mic_pos_frame[:, 0],
-                 mic_pos_frame[:, 1],
-                 mic_pos_frame[:, 2],
+             setattr(
+                 mic_scatter,
+                 "_offsets3d",
+                 (mic_pos_frame[:, 0], mic_pos_frame[:, 1], mic_pos_frame[:, 2]),
              )
              for s_idx, line in enumerate(src_lines):
                  xyz = src_frame[:, s_idx, :]
@@ -166,7 +166,10 @@ def animate_scene_gif(
          fps = frames / duration_s
      else:
          fps = 6.0
-     anim = animation.FuncAnimation(fig, _frame, frames=frames, interval=1000 / fps, blit=False)
-     anim.save(out_path, writer="pillow", fps=fps)
+     anim = animation.FuncAnimation(
+         fig, _frame, frames=frames, interval=1000 / fps, blit=False
+     )
+     fps_int = None if fps is None else max(1, int(round(fps)))
+     anim.save(out_path, writer="pillow", fps=fps_int)
      plt.close(fig)
      return out_path
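The `animation.py` change coerces a possibly fractional frame rate to a positive integer before handing it to the pillow writer. A few worked values of the same expression; the helper name `to_pillow_fps` is only for illustration and does not exist in the package.

```python
def to_pillow_fps(fps):
    # Same expression as the fps_int line added above.
    return None if fps is None else max(1, int(round(fps)))

assert to_pillow_fps(None) is None   # keep the writer's default behaviour
assert to_pillow_fps(0.4) == 1       # very long GIF durations still get >= 1 fps
assert to_pillow_fps(6.0) == 6
assert to_pillow_fps(23.7) == 24     # frames / duration_s rounded to the nearest int
```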
{torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/core.py
@@ -3,6 +3,7 @@ from __future__ import annotations
  """Core RIR simulation functions (static and dynamic)."""

  import math
+ from collections.abc import Callable
  from typing import Optional, Tuple

  import torch
@@ -61,8 +62,8 @@

      Example:
          >>> room = Room.shoebox(size=[6.0, 4.0, 3.0], fs=16000, beta=[0.9] * 6)
-         >>> sources = Source.positions([[1.0, 2.0, 1.5]])
-         >>> mics = MicrophoneArray.positions([[2.0, 2.0, 1.5]])
+         >>> sources = Source.from_positions([[1.0, 2.0, 1.5]])
+         >>> mics = MicrophoneArray.from_positions([[2.0, 2.0, 1.5]])
          >>> rir = simulate_rir(
          ...     room=room,
          ...     sources=sources,
@@ -90,9 +91,9 @@

      if not isinstance(room, Room):
          raise TypeError("room must be a Room instance")
-     if nsample is None and tmax is None:
-         raise ValueError("nsample or tmax must be provided")
      if nsample is None:
+         if tmax is None:
+             raise ValueError("nsample or tmax must be provided")
          nsample = int(math.ceil(tmax * room.fs))
      if nsample <= 0:
          raise ValueError("nsample must be positive")
@@ -261,6 +262,11 @@

      src_traj = as_tensor(src_traj, device=device, dtype=dtype)
      mic_traj = as_tensor(mic_traj, device=device, dtype=dtype)
+     device, dtype = infer_device_dtype(
+         src_traj, mic_traj, room.size, device=device, dtype=dtype
+     )
+     src_traj = as_tensor(src_traj, device=device, dtype=dtype)
+     mic_traj = as_tensor(mic_traj, device=device, dtype=dtype)

      if src_traj.ndim == 2:
          src_traj = src_traj.unsqueeze(1)
@@ -273,24 +279,95 @@
      if src_traj.shape[0] != mic_traj.shape[0]:
          raise ValueError("src_traj and mic_traj must have the same time length")

-     t_steps = src_traj.shape[0]
-     rirs = []
-     for t_idx in range(t_steps):
-         rir = simulate_rir(
-             room=room,
-             sources=src_traj[t_idx],
-             mics=mic_traj[t_idx],
-             max_order=max_order,
-             nsample=nsample,
-             tmax=tmax,
-             directivity=directivity,
-             orientation=orientation,
-             config=config,
-             device=device,
-             dtype=dtype,
+     if not isinstance(room, Room):
+         raise TypeError("room must be a Room instance")
+     if nsample is None:
+         if tmax is None:
+             raise ValueError("nsample or tmax must be provided")
+         nsample = int(math.ceil(tmax * room.fs))
+     if nsample <= 0:
+         raise ValueError("nsample must be positive")
+     if max_order < 0:
+         raise ValueError("max_order must be non-negative")
+
+     room_size = as_tensor(room.size, device=device, dtype=dtype)
+     room_size = ensure_dim(room_size)
+     dim = room_size.numel()
+     if src_traj.shape[2] != dim:
+         raise ValueError("src_traj must match room dimension")
+     if mic_traj.shape[2] != dim:
+         raise ValueError("mic_traj must match room dimension")
+
+     src_ori = None
+     mic_ori = None
+     if orientation is not None:
+         if isinstance(orientation, (list, tuple)):
+             if len(orientation) != 2:
+                 raise ValueError("orientation tuple must have length 2")
+             src_ori, mic_ori = orientation
+         else:
+             src_ori = orientation
+             mic_ori = orientation
+     if src_ori is not None:
+         src_ori = as_tensor(src_ori, device=device, dtype=dtype)
+     if mic_ori is not None:
+         mic_ori = as_tensor(mic_ori, device=device, dtype=dtype)
+
+     beta = _resolve_beta(room, room_size, device=device, dtype=dtype)
+     beta = _validate_beta(beta, dim)
+     n_vec = _image_source_indices(max_order, dim, device=device, nb_img=None)
+     refl = _reflection_coefficients(n_vec, beta)
+
+     src_pattern, mic_pattern = split_directivity(directivity)
+     mic_dir = None
+     if mic_pattern != "omni":
+         if mic_ori is None:
+             raise ValueError("mic orientation required for non-omni directivity")
+         mic_dir = orientation_to_unit(mic_ori, dim)
+
+     n_src = src_traj.shape[1]
+     n_mic = mic_traj.shape[1]
+     rirs = torch.zeros((src_traj.shape[0], n_src, n_mic, nsample), device=device, dtype=dtype)
+     fdl = cfg.frac_delay_length
+     fdl2 = (fdl - 1) // 2
+     img_chunk = cfg.image_chunk_size
+     if img_chunk <= 0:
+         img_chunk = n_vec.shape[0]
+
+     src_dirs = None
+     if src_pattern != "omni":
+         if src_ori is None:
+             raise ValueError("source orientation required for non-omni directivity")
+         src_dirs = orientation_to_unit(src_ori, dim)
+         if src_dirs.ndim == 1:
+             src_dirs = src_dirs.unsqueeze(0).repeat(n_src, 1)
+         if src_dirs.ndim != 2 or src_dirs.shape[0] != n_src:
+             raise ValueError("source orientation must match number of sources")
+
+     for start in range(0, n_vec.shape[0], img_chunk):
+         end = min(start + img_chunk, n_vec.shape[0])
+         n_vec_chunk = n_vec[start:end]
+         refl_chunk = refl[start:end]
+         sample_chunk, attenuation_chunk = _compute_image_contributions_time_batch(
+             src_traj,
+             mic_traj,
+             room_size,
+             n_vec_chunk,
+             refl_chunk,
+             room,
+             fdl2,
+             src_pattern=src_pattern,
+             mic_pattern=mic_pattern,
+             src_dirs=src_dirs,
+             mic_dir=mic_dir,
          )
-         rirs.append(rir)
-     return torch.stack(rirs, dim=0)
+         t_steps = src_traj.shape[0]
+         sample_flat = sample_chunk.reshape(t_steps * n_src, n_mic, -1)
+         attenuation_flat = attenuation_chunk.reshape(t_steps * n_src, n_mic, -1)
+         rir_flat = rirs.view(t_steps * n_src, n_mic, nsample)
+         _accumulate_rir_batch(rir_flat, sample_flat, attenuation_flat, cfg)
+
+     return rirs


  def _prepare_entities(
@@ -495,7 +572,11 @@ def _compute_image_contributions_batch(
      if mic_pattern != "omni":
          if mic_dir is None:
              raise ValueError("mic orientation required for non-omni directivity")
-         mic_dir = mic_dir[None, :, None, :] if mic_dir.ndim == 2 else mic_dir.view(1, 1, 1, -1)
+         mic_dir = (
+             mic_dir[None, :, None, :]
+             if mic_dir.ndim == 2
+             else mic_dir.view(1, 1, 1, -1)
+         )
          cos_theta = _cos_between(-vec, mic_dir)
          gain = gain * directivity_gain(mic_pattern, cos_theta)

@@ -503,6 +584,54 @@
      return sample, attenuation


+ def _compute_image_contributions_time_batch(
+     src_traj: Tensor,
+     mic_traj: Tensor,
+     room_size: Tensor,
+     n_vec: Tensor,
+     refl: Tensor,
+     room: Room,
+     fdl2: int,
+     *,
+     src_pattern: str,
+     mic_pattern: str,
+     src_dirs: Optional[Tensor],
+     mic_dir: Optional[Tensor],
+ ) -> Tuple[Tensor, Tensor]:
+     """Compute samples/attenuation for all time steps in batch."""
+     sign = torch.where((n_vec % 2) == 0, 1.0, -1.0).to(dtype=src_traj.dtype)
+     n = torch.floor_divide(n_vec + 1, 2).to(dtype=src_traj.dtype)
+     base = 2.0 * room_size * n
+     img = base[None, None, :, :] + sign[None, None, :, :] * src_traj[:, :, None, :]
+     vec = mic_traj[:, None, :, None, :] - img[:, :, None, :, :]
+     dist = torch.linalg.norm(vec, dim=-1)
+     dist = torch.clamp(dist, min=1e-6)
+     time = dist / room.c
+     time = time + (fdl2 / room.fs)
+     sample = time * room.fs
+
+     gain = refl.view(1, 1, 1, -1)
+     if src_pattern != "omni":
+         if src_dirs is None:
+             raise ValueError("source orientation required for non-omni directivity")
+         src_dirs_b = src_dirs[None, :, None, None, :]
+         cos_theta = _cos_between(vec, src_dirs_b)
+         gain = gain * directivity_gain(src_pattern, cos_theta)
+     if mic_pattern != "omni":
+         if mic_dir is None:
+             raise ValueError("mic orientation required for non-omni directivity")
+         mic_dir_b = (
+             mic_dir[None, None, :, None, :]
+             if mic_dir.ndim == 2
+             else mic_dir.view(1, 1, 1, 1, -1)
+         )
+         cos_theta = _cos_between(-vec, mic_dir_b)
+         gain = gain * directivity_gain(mic_pattern, cos_theta)
+
+     attenuation = gain / dist
+     return sample, attenuation
+
+
  def _select_orientation(orientation: Tensor, idx: int, count: int, dim: int) -> Tensor:
      """Pick the correct orientation vector for a given entity index."""
      if orientation.ndim == 0:
@@ -542,9 +671,9 @@ def _accumulate_rir(
      if use_lut:
          sinc_lut = _get_sinc_lut(fdl, lut_gran, device=rir.device, dtype=dtype)

-     mic_offsets = (torch.arange(n_mic, device=rir.device, dtype=torch.int64) * nsample).view(
-         n_mic, 1, 1
-     )
+     mic_offsets = (
+         torch.arange(n_mic, device=rir.device, dtype=torch.int64) * nsample
+     ).view(n_mic, 1, 1)
      rir_flat = rir.view(-1)

      chunk_size = cfg.accumulate_chunk_size
@@ -559,7 +688,9 @@
              x_off_frac = (1.0 - frac_m) * lut_gran
              lut_gran_off = torch.floor(x_off_frac).to(torch.int64)
              x_off = x_off_frac - lut_gran_off.to(dtype)
-             lut_pos = lut_gran_off[..., None] + (n[None, None, :].to(torch.int64) * lut_gran)
+             lut_pos = lut_gran_off[..., None] + (
+                 n[None, None, :].to(torch.int64) * lut_gran
+             )

              s0 = torch.take(sinc_lut, lut_pos)
              s1 = torch.take(sinc_lut, lut_pos + 1)
@@ -618,9 +749,9 @@ def _accumulate_rir_batch_impl(
      if use_lut:
          sinc_lut = _get_sinc_lut(fdl, lut_gran, device=rir.device, dtype=sample.dtype)

-     sm_offsets = (torch.arange(n_sm, device=rir.device, dtype=torch.int64) * nsample).view(
-         n_sm, 1, 1
-     )
+     sm_offsets = (
+         torch.arange(n_sm, device=rir.device, dtype=torch.int64) * nsample
+     ).view(n_sm, 1, 1)
      rir_flat = rir.view(-1)

      n_img = idx0.shape[1]
@@ -634,7 +765,9 @@
              x_off_frac = (1.0 - frac_m) * lut_gran
              lut_gran_off = torch.floor(x_off_frac).to(torch.int64)
              x_off = x_off_frac - lut_gran_off.to(sample.dtype)
-             lut_pos = lut_gran_off[..., None] + (n[None, None, :].to(torch.int64) * lut_gran)
+             lut_pos = lut_gran_off[..., None] + (
+                 n[None, None, :].to(torch.int64) * lut_gran
+             )

              s0 = torch.take(sinc_lut, lut_pos)
              s1 = torch.take(sinc_lut, lut_pos + 1)
@@ -660,12 +793,13 @@ _SINC_LUT_CACHE: dict[tuple[int, int, str, torch.dtype], Tensor] = {}
  _FDL_GRID_CACHE: dict[tuple[int, str, torch.dtype], Tensor] = {}
  _FDL_OFFSETS_CACHE: dict[tuple[int, str], Tensor] = {}
  _FDL_WINDOW_CACHE: dict[tuple[int, str, torch.dtype], Tensor] = {}
- _ACCUM_BATCH_COMPILED: dict[tuple[str, torch.dtype, int, int, bool, int], callable] = {}
+ _AccumFn = Callable[[Tensor, Tensor, Tensor], None]
+ _ACCUM_BATCH_COMPILED: dict[tuple[str, torch.dtype, int, int, bool, int], _AccumFn] = {}


  def _get_accumulate_fn(
      cfg: SimulationConfig, device: torch.device, dtype: torch.dtype
- ) -> callable:
+ ) -> _AccumFn:
      """Return an accumulation function with config-bound constants."""
      use_lut = cfg.use_lut and device.type != "mps"
      fdl = cfg.frac_delay_length
@@ -721,7 +855,9 @@ def _get_fdl_window(fdl: int, *, device: torch.device, dtype: torch.dtype) -> Te
      return cached


- def _get_sinc_lut(fdl: int, lut_gran: int, *, device: torch.device, dtype: torch.dtype) -> Tensor:
+ def _get_sinc_lut(
+     fdl: int, lut_gran: int, *, device: torch.device, dtype: torch.dtype
+ ) -> Tensor:
      """Create a sinc lookup table for fractional delays."""
      key = (fdl, lut_gran, str(device), dtype)
      cached = _SINC_LUT_CACHE.get(key)
@@ -765,7 +901,12 @@ def _apply_diffuse_tail(

      gen = torch.Generator(device=rir.device)
      gen.manual_seed(0 if seed is None else seed)
-     noise = torch.randn(rir[..., tdiff_idx:].shape, device=rir.device, dtype=rir.dtype, generator=gen)
-     scale = torch.linalg.norm(rir[..., tdiff_idx - 1 : tdiff_idx], dim=-1, keepdim=True) + 1e-8
+     noise = torch.randn(
+         rir[..., tdiff_idx:].shape, device=rir.device, dtype=rir.dtype, generator=gen
+     )
+     scale = (
+         torch.linalg.norm(rir[..., tdiff_idx - 1 : tdiff_idx], dim=-1, keepdim=True)
+         + 1e-8
+     )
      rir[..., tdiff_idx:] = noise * decay * scale
      return rir
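The reworked `simulate_dynamic_rir` replaces the per-time-step loop over `simulate_rir` with `_compute_image_contributions_time_batch`, which evaluates every image source for all time steps at once: an image position is `2 * room_size * n ± src`, the delay is distance over `c` plus the fractional-delay filter offset, and the attenuation is the reflection gain divided by distance. A standalone numeric sketch of those per-image formulas for the direct path only; the `fdl` value of 81 is an assumed config default, not taken from this diff.

```python
import torch

room_size = torch.tensor([6.0, 4.0, 3.0])
src = torch.tensor([1.0, 2.0, 1.5])
mic = torch.tensor([2.0, 2.0, 1.5])
c, fs, fdl2 = 343.0, 16000, (81 - 1) // 2  # fdl = 81 is an assumption

n_vec = torch.tensor([0.0, 0.0, 0.0])            # direct path: no reflections
sign = torch.where((n_vec % 2) == 0, 1.0, -1.0)  # same parity rule as the helper
img = 2.0 * room_size * torch.floor_divide(n_vec + 1, 2) + sign * src
dist = torch.linalg.norm(mic - img).clamp(min=1e-6)

sample = (dist / c + fdl2 / fs) * fs  # fractional sample index of this image's peak
attenuation = 1.0 / dist              # reflection coefficient is 1 for order 0
```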
{torchrir-0.1.2 → torchrir-0.2.0}/src/torchrir/datasets/__init__.py
@@ -1,14 +1,15 @@
  """Dataset helpers for torchrir."""

- from .base import BaseDataset, SentenceLike
- from .utils import choose_speakers, load_dataset_sources
+ from .base import BaseDataset, DatasetItem, SentenceLike
+ from .utils import choose_speakers, load_dataset_sources, load_wav_mono
+ from .collate import CollateBatch, collate_dataset_items
  from .template import TemplateDataset, TemplateSentence
+ from .librispeech import LibriSpeechDataset, LibriSpeechSentence

  from .cmu_arctic import (
      CmuArcticDataset,
      CmuArcticSentence,
      list_cmu_arctic_speakers,
-     load_wav_mono,
      save_wav,
  )

@@ -17,6 +18,9 @@ __all__ = [
      "CmuArcticDataset",
      "CmuArcticSentence",
      "choose_speakers",
+     "DatasetItem",
+     "CollateBatch",
+     "collate_dataset_items",
      "list_cmu_arctic_speakers",
      "SentenceLike",
      "load_dataset_sources",
@@ -24,4 +28,6 @@ __all__ = [
      "save_wav",
      "TemplateDataset",
      "TemplateSentence",
+     "LibriSpeechDataset",
+     "LibriSpeechSentence",
  ]
torchrir-0.2.0/src/torchrir/datasets/base.py
@@ -0,0 +1,67 @@
+ from __future__ import annotations
+
+ """Dataset protocol definitions."""
+
+ from dataclasses import dataclass
+ from typing import Optional, Protocol, Sequence, Tuple
+
+ import torch
+ from torch.utils.data import Dataset
+
+
+ class SentenceLike(Protocol):
+     """Minimal sentence interface for dataset entries."""
+
+     utterance_id: str
+     text: str
+
+
+ @dataclass(frozen=True)
+ class DatasetItem:
+     """Dataset item for DataLoader consumption."""
+
+     audio: torch.Tensor
+     sample_rate: int
+     utterance_id: str
+     text: Optional[str] = None
+     speaker: Optional[str] = None
+
+
+ class BaseDataset(Dataset[DatasetItem]):
+     """Base dataset class compatible with torch.utils.data.Dataset."""
+
+     _sentences_cache: Optional[list[SentenceLike]] = None
+
+     def list_speakers(self) -> list[str]:
+         """Return available speaker IDs."""
+         raise NotImplementedError
+
+     def available_sentences(self) -> Sequence[SentenceLike]:
+         """Return sentence entries that have audio available."""
+         raise NotImplementedError
+
+     def load_wav(self, utterance_id: str) -> Tuple[torch.Tensor, int]:
+         """Load audio for an utterance and return (audio, sample_rate)."""
+         raise NotImplementedError
+
+     def __len__(self) -> int:
+         return len(self._get_sentences())
+
+     def __getitem__(self, idx: int) -> DatasetItem:
+         sentences = self._get_sentences()
+         sentence = sentences[idx]
+         audio, sample_rate = self.load_wav(sentence.utterance_id)
+         speaker = getattr(self, "speaker", None)
+         text = getattr(sentence, "text", None)
+         return DatasetItem(
+             audio=audio,
+             sample_rate=sample_rate,
+             utterance_id=sentence.utterance_id,
+             text=text,
+             speaker=speaker,
+         )
+
+     def _get_sentences(self) -> list[SentenceLike]:
+         if self._sentences_cache is None:
+             self._sentences_cache = list(self.available_sentences())
+         return self._sentences_cache
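The new `BaseDataset` asks subclasses for three hooks, `list_speakers`, `available_sentences`, and `load_wav`; `__getitem__` then assembles a `DatasetItem` from them. A hypothetical minimal subclass, assuming `torchaudio` is available for decoding; `FolderDataset` and `FolderSentence` are illustrative names, not part of the package.

```python
from dataclasses import dataclass
from pathlib import Path

from torchrir.datasets import BaseDataset


@dataclass(frozen=True)
class FolderSentence:
    """Hypothetical SentenceLike entry backed by a WAV file on disk."""

    utterance_id: str
    text: str


class FolderDataset(BaseDataset):
    """Hypothetical subclass: one utterance per WAV file in a folder."""

    def __init__(self, root: str, speaker: str = "spk0") -> None:
        self.root = Path(root)
        self.speaker = speaker  # picked up by BaseDataset.__getitem__ via getattr

    def list_speakers(self) -> list[str]:
        return [self.speaker]

    def available_sentences(self):
        return [
            FolderSentence(utterance_id=p.stem, text="")
            for p in sorted(self.root.glob("*.wav"))
        ]

    def load_wav(self, utterance_id: str):
        import torchaudio  # assumed available for decoding

        audio, sample_rate = torchaudio.load(str(self.root / f"{utterance_id}.wav"))
        return audio.mean(dim=0), int(sample_rate)  # mono (samples,) tensor
```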