torchrir 0.1.2.tar.gz → 0.1.4.tar.gz

This diff shows the content of publicly released package versions as they appear in a supported public registry. It is provided for informational purposes only and reflects the changes between the two versions.
Files changed (44)
  1. torchrir-0.1.4/PKG-INFO +70 -0
  2. torchrir-0.1.4/README.md +57 -0
  3. {torchrir-0.1.2 → torchrir-0.1.4}/pyproject.toml +12 -1
  4. {torchrir-0.1.2 → torchrir-0.1.4}/src/torchrir/__init__.py +6 -1
  5. {torchrir-0.1.2 → torchrir-0.1.4}/src/torchrir/animation.py +17 -14
  6. {torchrir-0.1.2 → torchrir-0.1.4}/src/torchrir/core.py +35 -18
  7. {torchrir-0.1.2 → torchrir-0.1.4}/src/torchrir/datasets/cmu_arctic.py +4 -1
  8. {torchrir-0.1.2 → torchrir-0.1.4}/src/torchrir/datasets/template.py +3 -1
  9. {torchrir-0.1.2 → torchrir-0.1.4}/src/torchrir/datasets/utils.py +5 -1
  10. {torchrir-0.1.2 → torchrir-0.1.4}/src/torchrir/dynamic.py +3 -1
  11. {torchrir-0.1.2 → torchrir-0.1.4}/src/torchrir/plotting.py +13 -6
  12. {torchrir-0.1.2 → torchrir-0.1.4}/src/torchrir/plotting_utils.py +4 -1
  13. {torchrir-0.1.2 → torchrir-0.1.4}/src/torchrir/room.py +2 -38
  14. {torchrir-0.1.2 → torchrir-0.1.4}/src/torchrir/scene_utils.py +6 -2
  15. {torchrir-0.1.2 → torchrir-0.1.4}/src/torchrir/signal.py +24 -10
  16. {torchrir-0.1.2 → torchrir-0.1.4}/src/torchrir/simulators.py +12 -4
  17. {torchrir-0.1.2 → torchrir-0.1.4}/src/torchrir/utils.py +1 -1
  18. torchrir-0.1.4/src/torchrir.egg-info/PKG-INFO +70 -0
  19. {torchrir-0.1.2 → torchrir-0.1.4}/tests/test_compare_pyroomacoustics.py +5 -3
  20. {torchrir-0.1.2 → torchrir-0.1.4}/tests/test_core.py +6 -6
  21. {torchrir-0.1.2 → torchrir-0.1.4}/tests/test_device_parity.py +4 -4
  22. {torchrir-0.1.2 → torchrir-0.1.4}/tests/test_plotting.py +1 -0
  23. {torchrir-0.1.2 → torchrir-0.1.4}/tests/test_scene.py +25 -13
  24. torchrir-0.1.2/PKG-INFO +0 -271
  25. torchrir-0.1.2/README.md +0 -258
  26. torchrir-0.1.2/src/torchrir.egg-info/PKG-INFO +0 -271
  27. {torchrir-0.1.2 → torchrir-0.1.4}/LICENSE +0 -0
  28. {torchrir-0.1.2 → torchrir-0.1.4}/NOTICE +0 -0
  29. {torchrir-0.1.2 → torchrir-0.1.4}/setup.cfg +0 -0
  30. {torchrir-0.1.2 → torchrir-0.1.4}/src/torchrir/config.py +0 -0
  31. {torchrir-0.1.2 → torchrir-0.1.4}/src/torchrir/datasets/__init__.py +0 -0
  32. {torchrir-0.1.2 → torchrir-0.1.4}/src/torchrir/datasets/base.py +0 -0
  33. {torchrir-0.1.2 → torchrir-0.1.4}/src/torchrir/directivity.py +0 -0
  34. {torchrir-0.1.2 → torchrir-0.1.4}/src/torchrir/logging_utils.py +0 -0
  35. {torchrir-0.1.2 → torchrir-0.1.4}/src/torchrir/metadata.py +0 -0
  36. {torchrir-0.1.2 → torchrir-0.1.4}/src/torchrir/results.py +0 -0
  37. {torchrir-0.1.2 → torchrir-0.1.4}/src/torchrir/scene.py +0 -0
  38. {torchrir-0.1.2 → torchrir-0.1.4}/src/torchrir.egg-info/SOURCES.txt +0 -0
  39. {torchrir-0.1.2 → torchrir-0.1.4}/src/torchrir.egg-info/dependency_links.txt +0 -0
  40. {torchrir-0.1.2 → torchrir-0.1.4}/src/torchrir.egg-info/requires.txt +0 -0
  41. {torchrir-0.1.2 → torchrir-0.1.4}/src/torchrir.egg-info/top_level.txt +0 -0
  42. {torchrir-0.1.2 → torchrir-0.1.4}/tests/test_room.py +0 -0
  43. {torchrir-0.1.2 → torchrir-0.1.4}/tests/test_signal.py +0 -0
  44. {torchrir-0.1.2 → torchrir-0.1.4}/tests/test_utils.py +0 -0
@@ -0,0 +1,70 @@
+Metadata-Version: 2.4
+Name: torchrir
+Version: 0.1.4
+Summary: PyTorch-based room impulse response (RIR) simulation toolkit for static and dynamic scenes.
+Project-URL: Repository, https://github.com/taishi-n/torchrir
+Requires-Python: >=3.10
+Description-Content-Type: text/markdown
+License-File: LICENSE
+License-File: NOTICE
+Requires-Dist: numpy>=2.2.6
+Requires-Dist: torch>=2.10.0
+Dynamic: license-file
+
+# TorchRIR
+
+PyTorch-based room impulse response (RIR) simulation toolkit focused on a clean, modern API with GPU support.
+This project has been substantially assisted by AI using Codex.
+
+## Installation
+```bash
+pip install torchrir
+```
+
+## Examples
+- `examples/static.py`: fixed sources/mics with binaural output.
+  `uv run python examples/static.py --plot`
+- `examples/dynamic_src.py`: moving sources, fixed mics.
+  `uv run python examples/dynamic_src.py --plot`
+- `examples/dynamic_mic.py`: fixed sources, moving mics.
+  `uv run python examples/dynamic_mic.py --plot`
+- `examples/cli.py`: unified CLI for static/dynamic scenes, JSON/YAML configs.
+  `uv run python examples/cli.py --mode static --plot`
+- `examples/cmu_arctic_dynamic_dataset.py`: small dynamic dataset generator (fixed room/mics, randomized source motion).
+  `uv run python examples/cmu_arctic_dynamic_dataset.py --num-scenes 4 --num-sources 2`
+- `examples/benchmark_device.py`: CPU/GPU benchmark for RIR simulation.
+  `uv run python examples/benchmark_device.py --dynamic`
+
+## Core API Overview
+- Geometry: `Room`, `Source`, `MicrophoneArray`
+- Static RIR: `simulate_rir`
+- Dynamic RIR: `simulate_dynamic_rir`
+- Dynamic convolution: `DynamicConvolver`
+- Metadata export: `build_metadata`, `save_metadata_json`
+
+```python
+from torchrir import DynamicConvolver, MicrophoneArray, Room, Source, simulate_rir
+
+room = Room.shoebox(size=[6.0, 4.0, 3.0], fs=16000, beta=[0.9] * 6)
+sources = Source.from_positions([[1.0, 2.0, 1.5]])
+mics = MicrophoneArray.from_positions([[2.0, 2.0, 1.5]])
+
+rir = simulate_rir(room=room, sources=sources, mics=mics, max_order=6, tmax=0.3)
+# For dynamic scenes, compute rirs with simulate_dynamic_rir and convolve:
+# y = DynamicConvolver(mode="trajectory").convolve(signal, rirs)
+```
+
+For detailed documentation, see the docs under `docs/` and Read the Docs.
+
+## Future Work
+- Ray tracing backend: implement `RayTracingSimulator` with frequency-dependent absorption/scattering.
+- CUDA-native acceleration: introduce dedicated CUDA kernels for large-scale RIR generation.
+- Dataset expansion: add additional dataset integrations beyond CMU ARCTIC (see `TemplateDataset`).
+- Add regression tests comparing generated RIRs against gpuRIR outputs.
+
+## Related Libraries
+- [gpuRIR](https://github.com/DavidDiazGuerra/gpuRIR)
+- [Cross3D](https://github.com/DavidDiazGuerra/Cross3D)
+- [pyroomacoustics](https://github.com/LCAV/pyroomacoustics)
+- [das-generator](https://github.com/ehabets/das-generator)
+- [rir-generator](https://github.com/audiolabs/rir-generator)
@@ -0,0 +1,57 @@
+# TorchRIR
+
+PyTorch-based room impulse response (RIR) simulation toolkit focused on a clean, modern API with GPU support.
+This project has been substantially assisted by AI using Codex.
+
+## Installation
+```bash
+pip install torchrir
+```
+
+## Examples
+- `examples/static.py`: fixed sources/mics with binaural output.
+  `uv run python examples/static.py --plot`
+- `examples/dynamic_src.py`: moving sources, fixed mics.
+  `uv run python examples/dynamic_src.py --plot`
+- `examples/dynamic_mic.py`: fixed sources, moving mics.
+  `uv run python examples/dynamic_mic.py --plot`
+- `examples/cli.py`: unified CLI for static/dynamic scenes, JSON/YAML configs.
+  `uv run python examples/cli.py --mode static --plot`
+- `examples/cmu_arctic_dynamic_dataset.py`: small dynamic dataset generator (fixed room/mics, randomized source motion).
+  `uv run python examples/cmu_arctic_dynamic_dataset.py --num-scenes 4 --num-sources 2`
+- `examples/benchmark_device.py`: CPU/GPU benchmark for RIR simulation.
+  `uv run python examples/benchmark_device.py --dynamic`
+
+## Core API Overview
+- Geometry: `Room`, `Source`, `MicrophoneArray`
+- Static RIR: `simulate_rir`
+- Dynamic RIR: `simulate_dynamic_rir`
+- Dynamic convolution: `DynamicConvolver`
+- Metadata export: `build_metadata`, `save_metadata_json`
+
+```python
+from torchrir import DynamicConvolver, MicrophoneArray, Room, Source, simulate_rir
+
+room = Room.shoebox(size=[6.0, 4.0, 3.0], fs=16000, beta=[0.9] * 6)
+sources = Source.from_positions([[1.0, 2.0, 1.5]])
+mics = MicrophoneArray.from_positions([[2.0, 2.0, 1.5]])
+
+rir = simulate_rir(room=room, sources=sources, mics=mics, max_order=6, tmax=0.3)
+# For dynamic scenes, compute rirs with simulate_dynamic_rir and convolve:
+# y = DynamicConvolver(mode="trajectory").convolve(signal, rirs)
+```
+
+For detailed documentation, see the docs under `docs/` and Read the Docs.
+
+## Future Work
+- Ray tracing backend: implement `RayTracingSimulator` with frequency-dependent absorption/scattering.
+- CUDA-native acceleration: introduce dedicated CUDA kernels for large-scale RIR generation.
+- Dataset expansion: add additional dataset integrations beyond CMU ARCTIC (see `TemplateDataset`).
+- Add regression tests comparing generated RIRs against gpuRIR outputs.
+
+## Related Libraries
+- [gpuRIR](https://github.com/DavidDiazGuerra/gpuRIR)
+- [Cross3D](https://github.com/DavidDiazGuerra/Cross3D)
+- [pyroomacoustics](https://github.com/LCAV/pyroomacoustics)
+- [das-generator](https://github.com/ehabets/das-generator)
+- [rir-generator](https://github.com/audiolabs/rir-generator)
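The README example above leaves the dynamic-scene workflow as two comment lines. A minimal end-to-end sketch of that workflow follows. It is grounded in what this diff shows (`linear_trajectory(start, end, steps)` from `scene_utils`, `DynamicConvolver(mode="trajectory")`, a `(n_src, n_samples)` signal and `(t_steps, n_src, n_mic, rir_len)` RIRs per `signal.py`), but the keyword `src_traj` is an assumption, since the signature of `simulate_dynamic_rir` does not appear in this diff:

```python
import torch

from torchrir import (
    DynamicConvolver,
    MicrophoneArray,
    Room,
    Source,
    linear_trajectory,
    simulate_dynamic_rir,
)

# Static geometry, as in the README example above.
room = Room.shoebox(size=[6.0, 4.0, 3.0], fs=16000, beta=[0.9] * 6)
sources = Source.from_positions([[1.0, 1.0, 1.5]])
mics = MicrophoneArray.from_positions([[2.0, 2.0, 1.5]])

# linear_trajectory(start, end, steps) is exported by torchrir (see the
# scene_utils hunk below); here the source moves 2 m along y over 50 steps.
traj = linear_trajectory(
    torch.tensor([1.0, 1.0, 1.5]), torch.tensor([1.0, 3.0, 1.5]), steps=50
)

# NOTE: `src_traj` is a hypothetical keyword; consult the package docs for
# the actual simulate_dynamic_rir signature.
rirs = simulate_dynamic_rir(
    room=room, sources=sources, mics=mics, src_traj=traj, max_order=6, tmax=0.3
)

# signal: (n_src, n_samples); rirs: (t_steps, n_src, n_mic, rir_len).
signal = torch.randn(1, 16000)
y = DynamicConvolver(mode="trajectory").convolve(signal, rirs)
```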
@@ -1,6 +1,6 @@
 [project]
 name = "torchrir"
-version = "0.1.2"
+version = "0.1.4"
 description = "PyTorch-based room impulse response (RIR) simulation toolkit for static and dynamic scenes."
 readme = "README.md"
 requires-python = ">=3.10"
@@ -14,8 +14,10 @@ Repository = "https://github.com/taishi-n/torchrir"

 [dependency-groups]
 dev = [
+    "commitizen>=3.29.0",
     "git-cliff>=2.10.1",
     "matplotlib>=3.10.8",
+    "ruff>=0.12.2",
     "pillow>=11.2.1",
     "pyroomacoustics>=0.9.0",
     "pytest>=9.0.2",
@@ -23,4 +25,13 @@ dev = [
     "sphinx>=7.0,<8.2.3",
     "sphinx-rtd-theme>=2.0.0",
     "myst-parser>=2.0,<4.0",
+    "ty>=0.0.14,<0.1",
 ]
+
+[tool.commitizen]
+name = "cz_conventional_commits"
+tag_format = "v$version"
+version_scheme = "pep440"
+version_provider = "pep621"
+update_changelog_on_bump = true
+changelog_file = "CHANGELOG.md"
@@ -26,7 +26,12 @@ from .datasets import (
     load_wav_mono,
     save_wav,
 )
-from .scene_utils import binaural_mic_positions, clamp_positions, linear_trajectory, sample_positions
+from .scene_utils import (
+    binaural_mic_positions,
+    clamp_positions,
+    linear_trajectory,
+    sample_positions,
+)
 from .utils import (
     att2t_SabineEstimation,
     att2t_sabine_estimation,
@@ -104,15 +104,15 @@ def animate_scene_gif(
     mic_lines = []
     for _ in range(view_src_traj.shape[1]):
         if view_dim == 2:
-            line, = ax.plot([], [], color="tab:green", alpha=0.6)
+            (line,) = ax.plot([], [], color="tab:green", alpha=0.6)
         else:
-            line, = ax.plot([], [], [], color="tab:green", alpha=0.6)
+            (line,) = ax.plot([], [], [], color="tab:green", alpha=0.6)
         src_lines.append(line)
     for _ in range(view_mic_traj.shape[1]):
         if view_dim == 2:
-            line, = ax.plot([], [], color="tab:orange", alpha=0.6)
+            (line,) = ax.plot([], [], color="tab:orange", alpha=0.6)
         else:
-            line, = ax.plot([], [], [], color="tab:orange", alpha=0.6)
+            (line,) = ax.plot([], [], [], color="tab:orange", alpha=0.6)
         mic_lines.append(line)

     ax.legend(loc="best")
@@ -137,15 +137,15 @@ def animate_scene_gif(
             xy = mic_frame[:, m_idx, :]
             line.set_data(xy[:, 0], xy[:, 1])
         else:
-            src_scatter._offsets3d = (
-                src_pos_frame[:, 0],
-                src_pos_frame[:, 1],
-                src_pos_frame[:, 2],
+            setattr(
+                src_scatter,
+                "_offsets3d",
+                (src_pos_frame[:, 0], src_pos_frame[:, 1], src_pos_frame[:, 2]),
             )
-            mic_scatter._offsets3d = (
-                mic_pos_frame[:, 0],
-                mic_pos_frame[:, 1],
-                mic_pos_frame[:, 2],
+            setattr(
+                mic_scatter,
+                "_offsets3d",
+                (mic_pos_frame[:, 0], mic_pos_frame[:, 1], mic_pos_frame[:, 2]),
             )
         for s_idx, line in enumerate(src_lines):
             xyz = src_frame[:, s_idx, :]
@@ -166,7 +166,10 @@ def animate_scene_gif(
         fps = frames / duration_s
     else:
        fps = 6.0
-    anim = animation.FuncAnimation(fig, _frame, frames=frames, interval=1000 / fps, blit=False)
-    anim.save(out_path, writer="pillow", fps=fps)
+    anim = animation.FuncAnimation(
+        fig, _frame, frames=frames, interval=1000 / fps, blit=False
+    )
+    fps_int = None if fps is None else max(1, int(round(fps)))
+    anim.save(out_path, writer="pillow", fps=fps_int)
     plt.close(fig)
     return out_path
@@ -3,6 +3,7 @@ from __future__ import annotations
 """Core RIR simulation functions (static and dynamic)."""

 import math
+from collections.abc import Callable
 from typing import Optional, Tuple

 import torch
@@ -61,8 +62,8 @@ def simulate_rir(

     Example:
         >>> room = Room.shoebox(size=[6.0, 4.0, 3.0], fs=16000, beta=[0.9] * 6)
-        >>> sources = Source.positions([[1.0, 2.0, 1.5]])
-        >>> mics = MicrophoneArray.positions([[2.0, 2.0, 1.5]])
+        >>> sources = Source.from_positions([[1.0, 2.0, 1.5]])
+        >>> mics = MicrophoneArray.from_positions([[2.0, 2.0, 1.5]])
         >>> rir = simulate_rir(
         ...     room=room,
         ...     sources=sources,
@@ -90,9 +91,9 @@ def simulate_rir(

     if not isinstance(room, Room):
         raise TypeError("room must be a Room instance")
-    if nsample is None and tmax is None:
-        raise ValueError("nsample or tmax must be provided")
     if nsample is None:
+        if tmax is None:
+            raise ValueError("nsample or tmax must be provided")
         nsample = int(math.ceil(tmax * room.fs))
     if nsample <= 0:
         raise ValueError("nsample must be positive")
@@ -495,7 +496,11 @@ def _compute_image_contributions_batch(
     if mic_pattern != "omni":
         if mic_dir is None:
             raise ValueError("mic orientation required for non-omni directivity")
-        mic_dir = mic_dir[None, :, None, :] if mic_dir.ndim == 2 else mic_dir.view(1, 1, 1, -1)
+        mic_dir = (
+            mic_dir[None, :, None, :]
+            if mic_dir.ndim == 2
+            else mic_dir.view(1, 1, 1, -1)
+        )
         cos_theta = _cos_between(-vec, mic_dir)
         gain = gain * directivity_gain(mic_pattern, cos_theta)
@@ -542,9 +547,9 @@ def _accumulate_rir(
     if use_lut:
         sinc_lut = _get_sinc_lut(fdl, lut_gran, device=rir.device, dtype=dtype)

-    mic_offsets = (torch.arange(n_mic, device=rir.device, dtype=torch.int64) * nsample).view(
-        n_mic, 1, 1
-    )
+    mic_offsets = (
+        torch.arange(n_mic, device=rir.device, dtype=torch.int64) * nsample
+    ).view(n_mic, 1, 1)
     rir_flat = rir.view(-1)

     chunk_size = cfg.accumulate_chunk_size
@@ -559,7 +564,9 @@ def _accumulate_rir(
         x_off_frac = (1.0 - frac_m) * lut_gran
         lut_gran_off = torch.floor(x_off_frac).to(torch.int64)
         x_off = x_off_frac - lut_gran_off.to(dtype)
-        lut_pos = lut_gran_off[..., None] + (n[None, None, :].to(torch.int64) * lut_gran)
+        lut_pos = lut_gran_off[..., None] + (
+            n[None, None, :].to(torch.int64) * lut_gran
+        )

         s0 = torch.take(sinc_lut, lut_pos)
         s1 = torch.take(sinc_lut, lut_pos + 1)
@@ -618,9 +625,9 @@ def _accumulate_rir_batch_impl(
     if use_lut:
         sinc_lut = _get_sinc_lut(fdl, lut_gran, device=rir.device, dtype=sample.dtype)

-    sm_offsets = (torch.arange(n_sm, device=rir.device, dtype=torch.int64) * nsample).view(
-        n_sm, 1, 1
-    )
+    sm_offsets = (
+        torch.arange(n_sm, device=rir.device, dtype=torch.int64) * nsample
+    ).view(n_sm, 1, 1)
     rir_flat = rir.view(-1)

     n_img = idx0.shape[1]
@@ -634,7 +641,9 @@ def _accumulate_rir_batch_impl(
         x_off_frac = (1.0 - frac_m) * lut_gran
         lut_gran_off = torch.floor(x_off_frac).to(torch.int64)
         x_off = x_off_frac - lut_gran_off.to(sample.dtype)
-        lut_pos = lut_gran_off[..., None] + (n[None, None, :].to(torch.int64) * lut_gran)
+        lut_pos = lut_gran_off[..., None] + (
+            n[None, None, :].to(torch.int64) * lut_gran
+        )

         s0 = torch.take(sinc_lut, lut_pos)
         s1 = torch.take(sinc_lut, lut_pos + 1)
@@ -660,12 +669,13 @@ _SINC_LUT_CACHE: dict[tuple[int, int, str, torch.dtype], Tensor] = {}
 _FDL_GRID_CACHE: dict[tuple[int, str, torch.dtype], Tensor] = {}
 _FDL_OFFSETS_CACHE: dict[tuple[int, str], Tensor] = {}
 _FDL_WINDOW_CACHE: dict[tuple[int, str, torch.dtype], Tensor] = {}
-_ACCUM_BATCH_COMPILED: dict[tuple[str, torch.dtype, int, int, bool, int], callable] = {}
+_AccumFn = Callable[[Tensor, Tensor, Tensor], None]
+_ACCUM_BATCH_COMPILED: dict[tuple[str, torch.dtype, int, int, bool, int], _AccumFn] = {}


 def _get_accumulate_fn(
     cfg: SimulationConfig, device: torch.device, dtype: torch.dtype
-) -> callable:
+) -> _AccumFn:
     """Return an accumulation function with config-bound constants."""
     use_lut = cfg.use_lut and device.type != "mps"
     fdl = cfg.frac_delay_length
@@ -721,7 +731,9 @@ def _get_fdl_window(fdl: int, *, device: torch.device, dtype: torch.dtype) -> Tensor:
     return cached


-def _get_sinc_lut(fdl: int, lut_gran: int, *, device: torch.device, dtype: torch.dtype) -> Tensor:
+def _get_sinc_lut(
+    fdl: int, lut_gran: int, *, device: torch.device, dtype: torch.dtype
+) -> Tensor:
     """Create a sinc lookup table for fractional delays."""
     key = (fdl, lut_gran, str(device), dtype)
     cached = _SINC_LUT_CACHE.get(key)
@@ -765,7 +777,12 @@ def _apply_diffuse_tail(

     gen = torch.Generator(device=rir.device)
     gen.manual_seed(0 if seed is None else seed)
-    noise = torch.randn(rir[..., tdiff_idx:].shape, device=rir.device, dtype=rir.dtype, generator=gen)
-    scale = torch.linalg.norm(rir[..., tdiff_idx - 1 : tdiff_idx], dim=-1, keepdim=True) + 1e-8
+    noise = torch.randn(
+        rir[..., tdiff_idx:].shape, device=rir.device, dtype=rir.dtype, generator=gen
+    )
+    scale = (
+        torch.linalg.norm(rir[..., tdiff_idx - 1 : tdiff_idx], dim=-1, keepdim=True)
+        + 1e-8
+    )
     rir[..., tdiff_idx:] = noise * decay * scale
     return rir
@@ -44,6 +44,7 @@ def list_cmu_arctic_speakers() -> List[str]:
 @dataclass
 class CmuArcticSentence:
     """Sentence metadata from CMU ARCTIC."""
+
     utterance_id: str
     text: str

@@ -56,7 +57,9 @@ class CmuArcticDataset:
     >>> audio, fs = dataset.load_wav("arctic_a0001")
     """

-    def __init__(self, root: Path, speaker: str = "bdl", download: bool = False) -> None:
+    def __init__(
+        self, root: Path, speaker: str = "bdl", download: bool = False
+    ) -> None:
         """Initialize a CMU ARCTIC dataset handle.

         Args:
@@ -34,7 +34,9 @@ class TemplateDataset(BaseDataset):
     protocol intact.
     """

-    def __init__(self, root: Path, speaker: str = "default", download: bool = False) -> None:
+    def __init__(
+        self, root: Path, speaker: str = "default", download: bool = False
+    ) -> None:
         self.root = Path(root)
         self.speaker = speaker
         if download:
@@ -10,7 +10,9 @@ import torch
 from .base import BaseDataset, SentenceLike


-def choose_speakers(dataset: BaseDataset, num_sources: int, rng: random.Random) -> List[str]:
+def choose_speakers(
+    dataset: BaseDataset, num_sources: int, rng: random.Random
+) -> List[str]:
     """Select unique speakers for the requested number of sources.

     Example:
@@ -89,4 +91,6 @@ def load_dataset_sources(
         info.append((speaker, utterance_ids))

     stacked = torch.stack(signals, dim=0)
+    if fs is None:
+        raise RuntimeError("no audio loaded from dataset sources")
     return stacked, int(fs), info
@@ -44,7 +44,9 @@ class DynamicConvolver:
             if self.hop is None:
                 raise ValueError("hop must be provided for hop mode")
             return _convolve_dynamic_hop(signal, rirs, self.hop)
-        return _convolve_dynamic_trajectory(signal, rirs, timestamps=self.timestamps, fs=self.fs)
+        return _convolve_dynamic_trajectory(
+            signal, rirs, timestamps=self.timestamps, fs=self.fs
+        )


 def _convolve_dynamic_hop(signal: Tensor, rirs: Tensor, hop: int) -> Tensor:
@@ -92,7 +92,9 @@ def plot_scene_dynamic(
     return ax


-def _setup_axes(ax: Any | None, room: Room | Sequence[float] | Tensor) -> tuple[Any, Any]:
+def _setup_axes(
+    ax: Any | None, room: Room | Sequence[float] | Tensor
+) -> tuple[Any, Any]:
     """Create 2D/3D axes based on room dimension."""
     import matplotlib.pyplot as plt

@@ -131,8 +133,9 @@ def _draw_room_2d(ax: Any, size: Tensor) -> None:
     """Draw a 2D rectangular room."""
     import matplotlib.patches as patches

-    rect = patches.Rectangle((0.0, 0.0), size[0].item(), size[1].item(),
-                             fill=False, edgecolor="black")
+    rect = patches.Rectangle(
+        (0.0, 0.0), size[0].item(), size[1].item(), fill=False, edgecolor="black"
+    )
     ax.add_patch(rect)
     ax.set_xlim(0, size[0].item())
     ax.set_ylim(0, size[1].item())
@@ -186,7 +189,9 @@ def _draw_room_3d(ax: Any, size: Tensor) -> None:
     ax.set_zlabel("z")


-def _extract_positions(entity: Source | MicrophoneArray | Tensor | Sequence, ax: Any | None) -> Tensor:
+def _extract_positions(
+    entity: Source | MicrophoneArray | Tensor | Sequence, ax: Any | None
+) -> Tensor:
     """Extract positions from Source/MicrophoneArray or raw tensor."""
     if isinstance(entity, (Source, MicrophoneArray)):
         pos = entity.positions
@@ -211,7 +216,9 @@ def _scatter_positions(
         return
     dim = positions.shape[1]
     if dim == 2:
-        ax.scatter(positions[:, 0], positions[:, 1], label=label, marker=marker, color=color)
+        ax.scatter(
+            positions[:, 0], positions[:, 1], label=label, marker=marker, color=color
+        )
     else:
         ax.scatter(
             positions[:, 0],
@@ -300,4 +307,4 @@ def _is_moving(traj: Tensor, positions: Tensor, *, tol: float = 1e-6) -> bool:
     if traj.numel() == 0:
         return False
     pos0 = positions.unsqueeze(0).expand_as(traj)
-    return torch.any(torch.linalg.norm(traj - pos0, dim=-1) > tol).item()
+    return bool(torch.any(torch.linalg.norm(traj - pos0, dim=-1) > tol).item())
@@ -127,7 +127,10 @@ def _positions_to_cpu(entity: torch.Tensor | object) -> torch.Tensor:
     return pos


-def _traj_steps(src_traj: Optional[torch.Tensor | Sequence], mic_traj: Optional[torch.Tensor | Sequence]) -> int:
+def _traj_steps(
+    src_traj: Optional[torch.Tensor | Sequence],
+    mic_traj: Optional[torch.Tensor | Sequence],
+) -> int:
     """Infer the number of trajectory steps."""
     if src_traj is not None:
         return int(_to_cpu(src_traj).shape[0])
@@ -65,7 +65,7 @@ class Source:
     """Source container with positions and optional orientation.

     Example:
-        >>> sources = Source.positions([[1.0, 2.0, 1.5]])
+        >>> sources = Source.from_positions([[1.0, 2.0, 1.5]])
     """

     positions: Tensor
@@ -82,24 +82,6 @@ class Source:
         """Return a new Source with updated fields."""
         return replace(self, **kwargs)

-    @classmethod
-    def positions(
-        cls,
-        positions: Sequence[Sequence[float]] | Tensor,
-        *,
-        orientation: Optional[Sequence[float] | Tensor] = None,
-        device: Optional[torch.device | str] = None,
-        dtype: Optional[torch.dtype] = None,
-    ) -> "Source":
-        """Construct a Source from positions.
-
-        Example:
-            >>> sources = Source.positions([[1.0, 2.0, 1.5]])
-        """
-        return cls.from_positions(
-            positions, orientation=orientation, device=device, dtype=dtype
-        )
-
     @classmethod
     def from_positions(
         cls,
@@ -122,7 +104,7 @@ class MicrophoneArray:
     """Microphone array container.

     Example:
-        >>> mics = MicrophoneArray.positions([[2.0, 2.0, 1.5]])
+        >>> mics = MicrophoneArray.from_positions([[2.0, 2.0, 1.5]])
     """

     positions: Tensor
@@ -139,24 +121,6 @@ class MicrophoneArray:
         """Return a new MicrophoneArray with updated fields."""
         return replace(self, **kwargs)

-    @classmethod
-    def positions(
-        cls,
-        positions: Sequence[Sequence[float]] | Tensor,
-        *,
-        orientation: Optional[Sequence[float] | Tensor] = None,
-        device: Optional[torch.device | str] = None,
-        dtype: Optional[torch.dtype] = None,
-    ) -> "MicrophoneArray":
-        """Construct a MicrophoneArray from positions.
-
-        Example:
-            >>> mics = MicrophoneArray.positions([[2.0, 2.0, 1.5]])
-        """
-        return cls.from_positions(
-            positions, orientation=orientation, device=device, dtype=dtype
-        )
-
     @classmethod
     def from_positions(
         cls,
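The two removal hunks above drop the deprecated `positions` classmethods from `Source` and `MicrophoneArray`; the docstring hunks in `room.py` and `core.py` update the examples to match. Since the removed wrappers delegated to `from_positions` with identical arguments, migration is a rename (the keyword-only `orientation`/`device`/`dtype` options carry over):

```python
from torchrir import MicrophoneArray, Source

# 0.1.2 (removed in 0.1.4):
#   sources = Source.positions([[1.0, 2.0, 1.5]])
#   mics = MicrophoneArray.positions([[2.0, 2.0, 1.5]])

# 0.1.4: same arguments, new constructor name.
sources = Source.from_positions([[1.0, 2.0, 1.5]])
mics = MicrophoneArray.from_positions(
    [[2.0, 2.0, 1.5]], orientation=[0.0, 1.0, 0.0]
)
```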
@@ -32,7 +32,9 @@ def sample_positions(
     return torch.tensor(coords, dtype=torch.float32)


-def linear_trajectory(start: torch.Tensor, end: torch.Tensor, steps: int) -> torch.Tensor:
+def linear_trajectory(
+    start: torch.Tensor, end: torch.Tensor, steps: int
+) -> torch.Tensor:
     """Create a linear trajectory between start and end.

     Example:
@@ -58,7 +60,9 @@ def binaural_mic_positions(center: torch.Tensor, offset: float = 0.08) -> torch.Tensor:
     return torch.stack([left, right], dim=0)


-def clamp_positions(positions: torch.Tensor, room_size: torch.Tensor, margin: float = 0.1) -> torch.Tensor:
+def clamp_positions(
+    positions: torch.Tensor, room_size: torch.Tensor, margin: float = 0.1
+) -> torch.Tensor:
     """Clamp positions to remain inside the room with a margin.

     Example:
@@ -117,9 +117,9 @@ def _convolve_dynamic_rir_trajectory(
     else:
         step_fs = n_samples / t_steps
     ts_dtype = torch.float32 if signal.device.type == "mps" else torch.float64
-    w_ini = (torch.arange(t_steps, device=signal.device, dtype=ts_dtype) * step_fs).to(
-        torch.long
-    )
+    w_ini = (
+        torch.arange(t_steps, device=signal.device, dtype=ts_dtype) * step_fs
+    ).to(torch.long)

     w_ini = torch.cat(
         [w_ini, torch.tensor([n_samples], device=signal.device, dtype=torch.long)]
@@ -132,14 +132,18 @@ def _convolve_dynamic_rir_trajectory(
     )

     max_len = int(w_len.max().item())
-    segments = torch.zeros((t_steps, n_src, max_len), dtype=signal.dtype, device=signal.device)
+    segments = torch.zeros(
+        (t_steps, n_src, max_len), dtype=signal.dtype, device=signal.device
+    )
     for t in range(t_steps):
         start = int(w_ini[t].item())
         end = int(w_ini[t + 1].item())
         if end > start:
             segments[t, :, : end - start] = signal[:, start:end]

-    out = torch.zeros((n_mic, n_samples + rir_len - 1), dtype=signal.dtype, device=signal.device)
+    out = torch.zeros(
+        (n_mic, n_samples + rir_len - 1), dtype=signal.dtype, device=signal.device
+    )

     for t in range(t_steps):
         seg_len = int(w_len[t].item())
@@ -166,7 +170,9 @@ def _convolve_dynamic_rir_trajectory_batched(
     """GPU-friendly batched trajectory convolution using FFT."""
     n_samples = signal.shape[1]
     t_steps, n_src, n_mic, rir_len = rirs.shape
-    out = torch.zeros((n_mic, n_samples + rir_len - 1), dtype=signal.dtype, device=signal.device)
+    out = torch.zeros(
+        (n_mic, n_samples + rir_len - 1), dtype=signal.dtype, device=signal.device
+    )

     for t0 in range(0, t_steps, chunk_size):
         t1 = min(t0 + chunk_size, t_steps)
@@ -174,7 +180,9 @@ def _convolve_dynamic_rir_trajectory_batched(
         max_len = int(lengths.max().item())
         if max_len == 0:
             continue
-        segments = torch.zeros((t1 - t0, n_src, max_len), dtype=signal.dtype, device=signal.device)
+        segments = torch.zeros(
+            (t1 - t0, n_src, max_len), dtype=signal.dtype, device=signal.device
+        )
         for idx, t in enumerate(range(t0, t1)):
             start = int(w_ini[t].item())
             end = int(w_ini[t + 1].item())
@@ -190,7 +198,9 @@ def _convolve_dynamic_rir_trajectory_batched(
             dtype=signal.dtype,
             device=signal.device,
         )
-        conv = torch.fft.irfft(seg_f[:, :, None, :] * rir_f, n=fft_len, dim=-1, out=conv_out)
+        conv = torch.fft.irfft(
+            seg_f[:, :, None, :] * rir_f, n=fft_len, dim=-1, out=conv_out
+        )
         conv = conv[..., :conv_len]
         conv_sum = conv.sum(dim=1)
@@ -199,7 +209,9 @@ def _convolve_dynamic_rir_trajectory_batched(
             if seg_len == 0:
                 continue
             start = int(w_ini[t].item())
-            out[:, start : start + seg_len + rir_len - 1] += conv_sum[idx, :, : seg_len + rir_len - 1]
+            out[:, start : start + seg_len + rir_len - 1] += conv_sum[
+                idx, :, : seg_len + rir_len - 1
+            ]

     return out.squeeze(0) if n_mic == 1 else out
@@ -221,7 +233,9 @@ def _ensure_static_rirs(rirs: Tensor) -> Tensor:
         return rirs.view(1, rirs.shape[0], rirs.shape[1])
     if rirs.ndim == 3:
         return rirs
-    raise ValueError("rirs must have shape (rir_len,), (n_mic, rir_len), or (n_src, n_mic, rir_len)")
+    raise ValueError(
+        "rirs must have shape (rir_len,), (n_mic, rir_len), or (n_src, n_mic, rir_len)"
+    )


 def _ensure_dynamic_rirs(rirs: Tensor, signal: Tensor) -> Tensor:
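For reference, the three layouts named in the rewrapped error message are the static shapes `_ensure_static_rirs` normalizes toward `(n_src, n_mic, rir_len)`; a small illustrative sketch (values are arbitrary):

```python
import torch

rir_len = 512
rirs_1d = torch.zeros(rir_len)        # (rir_len,): single source/mic pair
rirs_2d = torch.zeros(2, rir_len)     # (n_mic, rir_len): one source, two mics
rirs_3d = torch.zeros(1, 2, rir_len)  # (n_src, n_mic, rir_len): fully explicit
# Per the hunk above, the 2D case is viewed as (1, n_mic, rir_len); the 1D
# case is presumably promoted to (1, 1, rir_len), though that branch is not
# shown in this diff.
```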