torchrir 0.1.2__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,175 @@
+ from __future__ import annotations
+
+ """LibriSpeech dataset helpers."""
+
+ import logging
+ import tarfile
+ import urllib.request
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import List, Tuple
+
+ import torch
+
+ from .base import BaseDataset
+ from .utils import load_wav_mono
+
+ BASE_URL = "https://www.openslr.org/resources/12"
+ VALID_SUBSETS = {
+     "dev-clean",
+     "dev-other",
+     "test-clean",
+     "test-other",
+     "train-clean-100",
+     "train-clean-360",
+     "train-other-500",
+ }
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class LibriSpeechSentence:
+     """Sentence metadata from LibriSpeech."""
+
+     utterance_id: str
+     text: str
+     speaker_id: str
+     chapter_id: str
+
+
+ class LibriSpeechDataset(BaseDataset):
+     """LibriSpeech dataset loader.
+
+     Example:
+         >>> dataset = LibriSpeechDataset(Path("datasets/librispeech"), subset="train-clean-100", download=True)
+         >>> audio, fs = dataset.load_wav("103-1240-0000")
+     """
+
+     def __init__(
+         self, root: Path, subset: str = "train-clean-100", download: bool = False
+     ) -> None:
+         """Initialize a LibriSpeech dataset handle.
+
+         Args:
+             root: Root directory where the dataset is stored.
+             subset: LibriSpeech subset name (e.g., "train-clean-100").
+             download: Download and extract if missing.
+         """
+         if subset not in VALID_SUBSETS:
+             raise ValueError(f"unsupported subset: {subset}")
+         self.root = Path(root)
+         self.subset = subset
+         self._archive_name = f"{subset}.tar.gz"
+         self._base_dir = self.root / "LibriSpeech"
+         self._subset_dir = self._base_dir / subset
+
+         if download:
+             self._download_and_extract()
+
+         if not self._subset_dir.exists():
+             raise FileNotFoundError(
+                 "dataset not found; run with download=True or place the archive under "
+                 f"{self.root}"
+             )
+
+     def list_speakers(self) -> List[str]:
+         """Return available speaker IDs."""
+         if not self._subset_dir.exists():
+             return []
+         return sorted([p.name for p in self._subset_dir.iterdir() if p.is_dir()])
+
+     def available_sentences(self) -> List[LibriSpeechSentence]:
+         """Return sentences that have a corresponding audio file."""
+         sentences: List[LibriSpeechSentence] = []
+         for trans_path in self._subset_dir.rglob("*.trans.txt"):
+             chapter_dir = trans_path.parent
+             speaker_id = chapter_dir.parent.name
+             chapter_id = chapter_dir.name
+             with trans_path.open("r", encoding="utf-8") as f:
+                 for line in f:
+                     line = line.strip()
+                     if not line:
+                         continue
+                     utt_id, text = _parse_text_line(line)
+                     wav_path = chapter_dir / f"{utt_id}.flac"
+                     if wav_path.exists():
+                         sentences.append(
+                             LibriSpeechSentence(
+                                 utterance_id=utt_id,
+                                 text=text,
+                                 speaker_id=speaker_id,
+                                 chapter_id=chapter_id,
+                             )
+                         )
+         return sentences
+
+     def load_wav(self, utterance_id: str) -> Tuple[torch.Tensor, int]:
+         """Load a mono wav for the given utterance ID."""
+         speaker_id, chapter_id, _ = utterance_id.split("-", 2)
+         path = self._subset_dir / speaker_id / chapter_id / f"{utterance_id}.flac"
+         return load_wav_mono(path)
+
+     def _download_and_extract(self) -> None:
+         """Download and extract the subset archive if needed."""
+         self.root.mkdir(parents=True, exist_ok=True)
+         archive_path = self.root / self._archive_name
+         url = f"{BASE_URL}/{self._archive_name}"
+
+         if not archive_path.exists():
+             logger.info("Downloading %s", url)
+             _download(url, archive_path)
+         if not self._subset_dir.exists():
+             logger.info("Extracting %s", archive_path)
+             try:
+                 with tarfile.open(archive_path, "r:gz") as tar:
+                     tar.extractall(self.root)
+             except (tarfile.ReadError, EOFError, OSError) as exc:
+                 logger.warning("Extraction failed (%s); re-downloading.", exc)
+                 if archive_path.exists():
+                     archive_path.unlink()
+                 _download(url, archive_path)
+                 with tarfile.open(archive_path, "r:gz") as tar:
+                     tar.extractall(self.root)
+
+
+ def _download(url: str, dest: Path, retries: int = 1) -> None:
+     """Download a file with retry and resume-safe temp file."""
+     for attempt in range(retries + 1):
+         try:
+             _stream_download(url, dest)
+             return
+         except Exception as exc:
+             if dest.exists():
+                 dest.unlink()
+             if attempt >= retries:
+                 raise
+             logger.warning("Download failed (%s); retrying...", exc)
+
+
+ def _stream_download(url: str, dest: Path) -> None:
+     """Stream a URL to disk with a progress indicator."""
+     tmp_path = dest.with_suffix(dest.suffix + ".part")
+     if tmp_path.exists():
+         tmp_path.unlink()
+
+     with urllib.request.urlopen(url) as response:
+         total = response.length or 0
+         downloaded = 0
+         chunk_size = 1024 * 1024
+         with tmp_path.open("wb") as f:
+             while True:
+                 chunk = response.read(chunk_size)
+                 if not chunk:
+                     break
+                 f.write(chunk)
+                 downloaded += len(chunk)
+     if total > 0 and downloaded != total:
+         raise IOError(f"incomplete download: {downloaded} of {total} bytes")
+     tmp_path.replace(dest)
+
+
+ def _parse_text_line(line: str) -> Tuple[str, str]:
+     """Parse a LibriSpeech transcript line into (utterance_id, text)."""
+     left, _, right = line.partition(" ")
+     return left, right.strip()
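
The hunk above is the new LibriSpeech helper module (`torchrir/datasets/librispeech.py` in the wheel RECORD). A minimal usage sketch based only on the class shown; the dataset directory is a placeholder and the import path is assumed from the RECORD:

```python
# Sketch only: assumes torchrir.datasets.librispeech is importable as listed in the RECORD
# and that ./datasets/librispeech is writable for the download.
from pathlib import Path

from torchrir.datasets.librispeech import LibriSpeechDataset

dataset = LibriSpeechDataset(Path("datasets/librispeech"), subset="dev-clean", download=True)
print(dataset.list_speakers()[:5])            # speaker directories found in the subset
sentences = dataset.available_sentences()     # only utterances whose .flac file exists
audio, fs = dataset.load_wav(sentences[0].utterance_id)
print(audio.shape, fs)
```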
@@ -34,7 +34,9 @@ class TemplateDataset(BaseDataset):
      protocol intact.
      """

-     def __init__(self, root: Path, speaker: str = "default", download: bool = False) -> None:
+     def __init__(
+         self, root: Path, speaker: str = "default", download: bool = False
+     ) -> None:
          self.root = Path(root)
          self.speaker = speaker
          if download:
@@ -3,6 +3,7 @@ from __future__ import annotations
  """Dataset-agnostic utilities."""

  import random
+ from pathlib import Path
  from typing import Callable, List, Optional, Sequence, Tuple

  import torch
@@ -10,7 +11,9 @@ import torch
  from .base import BaseDataset, SentenceLike


- def choose_speakers(dataset: BaseDataset, num_sources: int, rng: random.Random) -> List[str]:
+ def choose_speakers(
+     dataset: BaseDataset, num_sources: int, rng: random.Random
+ ) -> List[str]:
      """Select unique speakers for the requested number of sources.

      Example:
@@ -89,4 +92,23 @@ def load_dataset_sources(
          info.append((speaker, utterance_ids))

      stacked = torch.stack(signals, dim=0)
+     if fs is None:
+         raise RuntimeError("no audio loaded from dataset sources")
      return stacked, int(fs), info
+
+
+ def load_wav_mono(path: Path) -> Tuple[torch.Tensor, int]:
+     """Load a wav/flac file and return mono audio and sample rate.
+
+     Example:
+         >>> audio, fs = load_wav_mono(Path("datasets/cmu_arctic/ARCTIC/.../wav/arctic_a0001.wav"))
+     """
+     import soundfile as sf
+
+     audio, sample_rate = sf.read(str(path), dtype="float32", always_2d=True)
+     audio_t = torch.from_numpy(audio)
+     if audio_t.shape[1] > 1:
+         audio_t = audio_t.mean(dim=1)
+     else:
+         audio_t = audio_t.squeeze(1)
+     return audio_t, sample_rate
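
The new `load_wav_mono` reads any soundfile-supported format as float32 and averages channels when the file is not already mono. A standalone sketch of the same downmix, assuming `soundfile` is installed and `example.flac` is a placeholder path:

```python
import soundfile as sf
import torch

# always_2d=True guarantees a (frames, channels) array even for mono files.
audio, fs = sf.read("example.flac", dtype="float32", always_2d=True)
audio_t = torch.from_numpy(audio)
mono = audio_t.mean(dim=1) if audio_t.shape[1] > 1 else audio_t.squeeze(1)
print(mono.shape, fs)
```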
torchrir/dynamic.py CHANGED
@@ -44,7 +44,9 @@ class DynamicConvolver:
              if self.hop is None:
                  raise ValueError("hop must be provided for hop mode")
              return _convolve_dynamic_hop(signal, rirs, self.hop)
-         return _convolve_dynamic_trajectory(signal, rirs, timestamps=self.timestamps, fs=self.fs)
+         return _convolve_dynamic_trajectory(
+             signal, rirs, timestamps=self.timestamps, fs=self.fs
+         )


  def _convolve_dynamic_hop(signal: Tensor, rirs: Tensor, hop: int) -> Tensor:
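
The reformatted return above is the trajectory branch of `DynamicConvolver`; hop mode requires `hop`, while trajectory mode uses `timestamps`/`fs`. A hedged sketch of both calls, assuming `hop` is accepted as a constructor argument (the diff only shows it as `self.hop`) and that `signal` and `rirs` were produced elsewhere:

```python
from torchrir import DynamicConvolver

# Trajectory mode, as in the README example: RIRs are spread along the motion trajectory.
y_traj = DynamicConvolver(mode="trajectory").convolve(signal, rirs)

# Hop mode: per the check above, a hop length (in samples) must be provided.
y_hop = DynamicConvolver(mode="hop", hop=1024).convolve(signal, rirs)
```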
torchrir/plotting.py CHANGED
@@ -92,7 +92,9 @@ def plot_scene_dynamic(
      return ax


- def _setup_axes(ax: Any | None, room: Room | Sequence[float] | Tensor) -> tuple[Any, Any]:
+ def _setup_axes(
+     ax: Any | None, room: Room | Sequence[float] | Tensor
+ ) -> tuple[Any, Any]:
      """Create 2D/3D axes based on room dimension."""
      import matplotlib.pyplot as plt

@@ -131,8 +133,9 @@ def _draw_room_2d(ax: Any, size: Tensor) -> None:
      """Draw a 2D rectangular room."""
      import matplotlib.patches as patches

-     rect = patches.Rectangle((0.0, 0.0), size[0].item(), size[1].item(),
-                              fill=False, edgecolor="black")
+     rect = patches.Rectangle(
+         (0.0, 0.0), size[0].item(), size[1].item(), fill=False, edgecolor="black"
+     )
      ax.add_patch(rect)
      ax.set_xlim(0, size[0].item())
      ax.set_ylim(0, size[1].item())
@@ -186,7 +189,9 @@ def _draw_room_3d(ax: Any, size: Tensor) -> None:
      ax.set_zlabel("z")


- def _extract_positions(entity: Source | MicrophoneArray | Tensor | Sequence, ax: Any | None) -> Tensor:
+ def _extract_positions(
+     entity: Source | MicrophoneArray | Tensor | Sequence, ax: Any | None
+ ) -> Tensor:
      """Extract positions from Source/MicrophoneArray or raw tensor."""
      if isinstance(entity, (Source, MicrophoneArray)):
          pos = entity.positions
@@ -211,7 +216,9 @@ def _scatter_positions(
          return
      dim = positions.shape[1]
      if dim == 2:
-         ax.scatter(positions[:, 0], positions[:, 1], label=label, marker=marker, color=color)
+         ax.scatter(
+             positions[:, 0], positions[:, 1], label=label, marker=marker, color=color
+         )
      else:
          ax.scatter(
              positions[:, 0],
@@ -300,4 +307,4 @@ def _is_moving(traj: Tensor, positions: Tensor, *, tol: float = 1e-6) -> bool:
      if traj.numel() == 0:
          return False
      pos0 = positions.unsqueeze(0).expand_as(traj)
-     return torch.any(torch.linalg.norm(traj - pos0, dim=-1) > tol).item()
+     return bool(torch.any(torch.linalg.norm(traj - pos0, dim=-1) > tol).item())
@@ -127,7 +127,10 @@ def _positions_to_cpu(entity: torch.Tensor | object) -> torch.Tensor:
      return pos


- def _traj_steps(src_traj: Optional[torch.Tensor | Sequence], mic_traj: Optional[torch.Tensor | Sequence]) -> int:
+ def _traj_steps(
+     src_traj: Optional[torch.Tensor | Sequence],
+     mic_traj: Optional[torch.Tensor | Sequence],
+ ) -> int:
      """Infer the number of trajectory steps."""
      if src_traj is not None:
          return int(_to_cpu(src_traj).shape[0])
torchrir/room.py CHANGED
@@ -65,7 +65,7 @@ class Source:
      """Source container with positions and optional orientation.

      Example:
-         >>> sources = Source.positions([[1.0, 2.0, 1.5]])
+         >>> sources = Source.from_positions([[1.0, 2.0, 1.5]])
      """

      positions: Tensor
@@ -82,24 +82,6 @@ class Source:
          """Return a new Source with updated fields."""
          return replace(self, **kwargs)

-     @classmethod
-     def positions(
-         cls,
-         positions: Sequence[Sequence[float]] | Tensor,
-         *,
-         orientation: Optional[Sequence[float] | Tensor] = None,
-         device: Optional[torch.device | str] = None,
-         dtype: Optional[torch.dtype] = None,
-     ) -> "Source":
-         """Construct a Source from positions.
-
-         Example:
-             >>> sources = Source.positions([[1.0, 2.0, 1.5]])
-         """
-         return cls.from_positions(
-             positions, orientation=orientation, device=device, dtype=dtype
-         )
-
      @classmethod
      def from_positions(
          cls,
@@ -122,7 +104,7 @@ class MicrophoneArray:
      """Microphone array container.

      Example:
-         >>> mics = MicrophoneArray.positions([[2.0, 2.0, 1.5]])
+         >>> mics = MicrophoneArray.from_positions([[2.0, 2.0, 1.5]])
      """

      positions: Tensor
@@ -139,24 +121,6 @@ class MicrophoneArray:
          """Return a new MicrophoneArray with updated fields."""
          return replace(self, **kwargs)

-     @classmethod
-     def positions(
-         cls,
-         positions: Sequence[Sequence[float]] | Tensor,
-         *,
-         orientation: Optional[Sequence[float] | Tensor] = None,
-         device: Optional[torch.device | str] = None,
-         dtype: Optional[torch.dtype] = None,
-     ) -> "MicrophoneArray":
-         """Construct a MicrophoneArray from positions.
-
-         Example:
-             >>> mics = MicrophoneArray.positions([[2.0, 2.0, 1.5]])
-         """
-         return cls.from_positions(
-             positions, orientation=orientation, device=device, dtype=dtype
-         )
-
      @classmethod
      def from_positions(
          cls,
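
The `positions` classmethod aliases are removed from both containers in 0.2.0; `from_positions` is the remaining constructor, and the docstring examples above were updated accordingly. A short migration sketch with illustrative coordinates:

```python
from torchrir import MicrophoneArray, Source

# 0.1.2 (removed in 0.2.0):
# sources = Source.positions([[1.0, 2.0, 1.5]])
# mics = MicrophoneArray.positions([[2.0, 2.0, 1.5]])

# 0.2.0:
sources = Source.from_positions([[1.0, 2.0, 1.5]])
mics = MicrophoneArray.from_positions([[2.0, 2.0, 1.5]])
```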
torchrir/scene_utils.py CHANGED
@@ -32,7 +32,9 @@ def sample_positions(
      return torch.tensor(coords, dtype=torch.float32)


- def linear_trajectory(start: torch.Tensor, end: torch.Tensor, steps: int) -> torch.Tensor:
+ def linear_trajectory(
+     start: torch.Tensor, end: torch.Tensor, steps: int
+ ) -> torch.Tensor:
      """Create a linear trajectory between start and end.

      Example:
@@ -58,7 +60,9 @@ def binaural_mic_positions(center: torch.Tensor, offset: float = 0.08) -> torch.
      return torch.stack([left, right], dim=0)


- def clamp_positions(positions: torch.Tensor, room_size: torch.Tensor, margin: float = 0.1) -> torch.Tensor:
+ def clamp_positions(
+     positions: torch.Tensor, room_size: torch.Tensor, margin: float = 0.1
+ ) -> torch.Tensor:
      """Clamp positions to remain inside the room with a margin.

      Example:
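
For reference, a hedged usage sketch of the two reformatted helpers, based only on the signatures and docstrings shown here; the values are illustrative, and passing a whole trajectory to `clamp_positions` is an assumption:

```python
import torch

from torchrir.scene_utils import clamp_positions, linear_trajectory

room_size = torch.tensor([6.0, 4.0, 3.0])
traj = linear_trajectory(torch.tensor([1.0, 1.0, 1.5]), torch.tensor([5.5, 3.5, 1.5]), steps=50)
traj = clamp_positions(traj, room_size, margin=0.1)  # keep points at least 0.1 m from the walls
```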
torchrir/signal.py CHANGED
@@ -117,9 +117,9 @@ def _convolve_dynamic_rir_trajectory(
      else:
          step_fs = n_samples / t_steps
      ts_dtype = torch.float32 if signal.device.type == "mps" else torch.float64
-     w_ini = (torch.arange(t_steps, device=signal.device, dtype=ts_dtype) * step_fs).to(
-         torch.long
-     )
+     w_ini = (
+         torch.arange(t_steps, device=signal.device, dtype=ts_dtype) * step_fs
+     ).to(torch.long)

      w_ini = torch.cat(
          [w_ini, torch.tensor([n_samples], device=signal.device, dtype=torch.long)]
@@ -132,14 +132,18 @@ def _convolve_dynamic_rir_trajectory(
      )

      max_len = int(w_len.max().item())
-     segments = torch.zeros((t_steps, n_src, max_len), dtype=signal.dtype, device=signal.device)
+     segments = torch.zeros(
+         (t_steps, n_src, max_len), dtype=signal.dtype, device=signal.device
+     )
      for t in range(t_steps):
          start = int(w_ini[t].item())
          end = int(w_ini[t + 1].item())
          if end > start:
              segments[t, :, : end - start] = signal[:, start:end]

-     out = torch.zeros((n_mic, n_samples + rir_len - 1), dtype=signal.dtype, device=signal.device)
+     out = torch.zeros(
+         (n_mic, n_samples + rir_len - 1), dtype=signal.dtype, device=signal.device
+     )

      for t in range(t_steps):
          seg_len = int(w_len[t].item())
@@ -166,7 +170,9 @@ def _convolve_dynamic_rir_trajectory_batched(
      """GPU-friendly batched trajectory convolution using FFT."""
      n_samples = signal.shape[1]
      t_steps, n_src, n_mic, rir_len = rirs.shape
-     out = torch.zeros((n_mic, n_samples + rir_len - 1), dtype=signal.dtype, device=signal.device)
+     out = torch.zeros(
+         (n_mic, n_samples + rir_len - 1), dtype=signal.dtype, device=signal.device
+     )

      for t0 in range(0, t_steps, chunk_size):
          t1 = min(t0 + chunk_size, t_steps)
@@ -174,7 +180,9 @@ def _convolve_dynamic_rir_trajectory_batched(
          max_len = int(lengths.max().item())
          if max_len == 0:
              continue
-         segments = torch.zeros((t1 - t0, n_src, max_len), dtype=signal.dtype, device=signal.device)
+         segments = torch.zeros(
+             (t1 - t0, n_src, max_len), dtype=signal.dtype, device=signal.device
+         )
          for idx, t in enumerate(range(t0, t1)):
              start = int(w_ini[t].item())
              end = int(w_ini[t + 1].item())
@@ -190,7 +198,9 @@ def _convolve_dynamic_rir_trajectory_batched(
              dtype=signal.dtype,
              device=signal.device,
          )
-         conv = torch.fft.irfft(seg_f[:, :, None, :] * rir_f, n=fft_len, dim=-1, out=conv_out)
+         conv = torch.fft.irfft(
+             seg_f[:, :, None, :] * rir_f, n=fft_len, dim=-1, out=conv_out
+         )
          conv = conv[..., :conv_len]
          conv_sum = conv.sum(dim=1)
@@ -199,7 +209,9 @@ def _convolve_dynamic_rir_trajectory_batched(
              if seg_len == 0:
                  continue
              start = int(w_ini[t].item())
-             out[:, start : start + seg_len + rir_len - 1] += conv_sum[idx, :, : seg_len + rir_len - 1]
+             out[:, start : start + seg_len + rir_len - 1] += conv_sum[
+                 idx, :, : seg_len + rir_len - 1
+             ]

      return out.squeeze(0) if n_mic == 1 else out

@@ -221,7 +233,9 @@ def _ensure_static_rirs(rirs: Tensor) -> Tensor:
          return rirs.view(1, rirs.shape[0], rirs.shape[1])
      if rirs.ndim == 3:
          return rirs
-     raise ValueError("rirs must have shape (rir_len,), (n_mic, rir_len), or (n_src, n_mic, rir_len)")
+     raise ValueError(
+         "rirs must have shape (rir_len,), (n_mic, rir_len), or (n_src, n_mic, rir_len)"
+     )


  def _ensure_dynamic_rirs(rirs: Tensor, signal: Tensor) -> Tensor:
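
The batched trajectory path above multiplies segment and RIR spectra and inverse-transforms, i.e. full linear convolution via `torch.fft`. A minimal standalone sketch of that identity; shapes and variable names are illustrative, not the library's internals:

```python
import torch

sig = torch.randn(3, 16000)     # (n_src, n_samples)
rir = torch.randn(3, 2, 4096)   # (n_src, n_mic, rir_len)

fft_len = sig.shape[-1] + rir.shape[-1] - 1                    # length of the full linear convolution
sig_f = torch.fft.rfft(sig, n=fft_len)                         # (n_src, n_bins)
rir_f = torch.fft.rfft(rir, n=fft_len)                         # (n_src, n_mic, n_bins)
conv = torch.fft.irfft(sig_f[:, None, :] * rir_f, n=fft_len)   # (n_src, n_mic, fft_len)
out = conv.sum(dim=0)                                          # mix sources -> (n_mic, n_samples + rir_len - 1)
```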
torchrir/simulators.py CHANGED
@@ -18,7 +18,9 @@ from .scene import Scene
  class RIRSimulator(Protocol):
      """Strategy interface for RIR simulation backends."""

-     def simulate(self, scene: Scene, config: SimulationConfig | None = None) -> RIRResult:
+     def simulate(
+         self, scene: Scene, config: SimulationConfig | None = None
+     ) -> RIRResult:
          """Run a simulation and return the result."""


@@ -30,7 +32,9 @@ class ISMSimulator:
      >>> result = ISMSimulator().simulate(scene, config)
      """

-     def simulate(self, scene: Scene, config: SimulationConfig | None = None) -> RIRResult:
+     def simulate(
+         self, scene: Scene, config: SimulationConfig | None = None
+     ) -> RIRResult:
          scene.validate()
          cfg = config or default_config()
          if scene.is_dynamic():
@@ -71,7 +75,9 @@ class RayTracingSimulator:
      reuse Scene/SimulationConfig for inputs and keep output shape parity.
      """

-     def simulate(self, scene: Scene, config: SimulationConfig | None = None) -> RIRResult:
+     def simulate(
+         self, scene: Scene, config: SimulationConfig | None = None
+     ) -> RIRResult:
          raise NotImplementedError("RayTracingSimulator is not implemented yet")


@@ -86,5 +92,7 @@ class FDTDSimulator:
      RIRResult with the same metadata contract as ISM.
      """

-     def simulate(self, scene: Scene, config: SimulationConfig | None = None) -> RIRResult:
+     def simulate(
+         self, scene: Scene, config: SimulationConfig | None = None
+     ) -> RIRResult:
          raise NotImplementedError("FDTDSimulator is not implemented yet")
torchrir/utils.py CHANGED
@@ -15,7 +15,7 @@ _DEF_SPEED_OF_SOUND = 343.0


  def as_tensor(
-     value: Tensor | Iterable[float] | float | int,
+     value: Tensor | Iterable[float] | Iterable[Iterable[float]] | float | int,
      *,
      device: Optional[torch.device | str] = None,
      dtype: Optional[torch.dtype] = None,
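
The widened annotation documents that `as_tensor` also accepts nested sequences (e.g., lists of positions), not just flat iterables or scalars. A trivial sketch under the assumption that it mirrors `torch.as_tensor` semantics:

```python
import torch

from torchrir.utils import as_tensor

pos = as_tensor([[1.0, 2.0, 1.5], [2.0, 2.0, 1.5]], dtype=torch.float32)  # nested sequence, per the new annotation
c = as_tensor(343.0)                                                      # scalars are still accepted
```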
@@ -0,0 +1,70 @@
+ Metadata-Version: 2.4
+ Name: torchrir
+ Version: 0.2.0
+ Summary: PyTorch-based room impulse response (RIR) simulation toolkit for static and dynamic scenes.
+ Project-URL: Repository, https://github.com/taishi-n/torchrir
+ Requires-Python: >=3.10
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ License-File: NOTICE
+ Requires-Dist: numpy>=2.2.6
+ Requires-Dist: torch>=2.10.0
+ Dynamic: license-file
+
+ # TorchRIR
+
+ PyTorch-based room impulse response (RIR) simulation toolkit focused on a clean, modern API with GPU support.
+ Development of this project has been substantially assisted by AI (Codex).
+
+ ## Installation
+ ```bash
+ pip install torchrir
+ ```
+
+ ## Examples
+ - `examples/static.py`: fixed sources/mics with binaural output.
+   `uv run python examples/static.py --plot`
+ - `examples/dynamic_src.py`: moving sources, fixed mics.
+   `uv run python examples/dynamic_src.py --plot`
+ - `examples/dynamic_mic.py`: fixed sources, moving mics.
+   `uv run python examples/dynamic_mic.py --plot`
+ - `examples/cli.py`: unified CLI for static/dynamic scenes, JSON/YAML configs.
+   `uv run python examples/cli.py --mode static --plot`
+ - `examples/cmu_arctic_dynamic_dataset.py`: small dynamic dataset generator (fixed room/mics, randomized source motion).
+   `uv run python examples/cmu_arctic_dynamic_dataset.py --num-scenes 4 --num-sources 2`
+ - `examples/benchmark_device.py`: CPU/GPU benchmark for RIR simulation.
+   `uv run python examples/benchmark_device.py --dynamic`
+
+ ## Core API Overview
+ - Geometry: `Room`, `Source`, `MicrophoneArray`
+ - Static RIR: `simulate_rir`
+ - Dynamic RIR: `simulate_dynamic_rir`
+ - Dynamic convolution: `DynamicConvolver`
+ - Metadata export: `build_metadata`, `save_metadata_json`
+
+ ```python
+ from torchrir import DynamicConvolver, MicrophoneArray, Room, Source, simulate_rir
+
+ room = Room.shoebox(size=[6.0, 4.0, 3.0], fs=16000, beta=[0.9] * 6)
+ sources = Source.from_positions([[1.0, 2.0, 1.5]])
+ mics = MicrophoneArray.from_positions([[2.0, 2.0, 1.5]])
+
+ rir = simulate_rir(room=room, sources=sources, mics=mics, max_order=6, tmax=0.3)
+ # For dynamic scenes, compute rirs with simulate_dynamic_rir and convolve:
+ # y = DynamicConvolver(mode="trajectory").convolve(signal, rirs)
+ ```
+
+ For detailed documentation, see the docs under `docs/` and Read the Docs.
+
+ ## Future Work
+ - Ray tracing backend: implement `RayTracingSimulator` with frequency-dependent absorption/scattering.
+ - CUDA-native acceleration: introduce dedicated CUDA kernels for large-scale RIR generation.
+ - Dataset expansion: add dataset integrations beyond CMU ARCTIC (see `TemplateDataset`), including torchaudio datasets (e.g., LibriSpeech, VCTK, LibriTTS, SpeechCommands, CommonVoice, GTZAN, MUSDB-HQ).
+ - Add regression tests comparing generated RIRs against gpuRIR outputs.
+
+ ## Related Libraries
+ - [gpuRIR](https://github.com/DavidDiazGuerra/gpuRIR)
+ - [Cross3D](https://github.com/DavidDiazGuerra/Cross3D)
+ - [pyroomacoustics](https://github.com/LCAV/pyroomacoustics)
+ - [das-generator](https://github.com/ehabets/das-generator)
+ - [rir-generator](https://github.com/audiolabs/rir-generator)
@@ -0,0 +1,30 @@
+ torchrir/__init__.py,sha256=MTlouAErvB7IyM4pcmhEN1U0KQsZBglWxUoHxnZDa5U,2615
+ torchrir/animation.py,sha256=x3Y-BLz3J6DQNmoDIjbMEgGfng2yavJFLyQEmRCSpQU,6391
+ torchrir/config.py,sha256=PsZdDIS3p4jepeNSHyd69aSD9QlOEdpG9v1SAXlZ_Fg,2295
+ torchrir/core.py,sha256=Ug5thts1rXvCpdq9twVHz72oWygmk5J6gliloozHKL4,31704
+ torchrir/directivity.py,sha256=v_t37YgeXF_IYzbnrk0TCs1npb_0yKR7zHiG8XV3V4w,1259
+ torchrir/dynamic.py,sha256=01JHMxhORdcz93J-YaMIeSLo7k2tHrZke8llPHHXwZg,2153
+ torchrir/logging_utils.py,sha256=s4jDSSDoHT0HKeplDUpGMsdeBij4eibLSpaaAPzkB68,2146
+ torchrir/metadata.py,sha256=cwoXrr_yE2bQRUPnJe6p7POMPCWa9_oabCtp--WqBE8,6958
+ torchrir/plotting.py,sha256=TM1LxPitZq5KXdNe1GfUCOnzFOzerGhWIFblzIz142A,8170
+ torchrir/plotting_utils.py,sha256=Kg3TCLqEq_lxVQkYHI6vyz_6oG3Ic_Z8H9gZN-39QeI,5180
+ torchrir/results.py,sha256=-HczEfr2u91BNb1xbrIGKCj0G3yzy7l_fmUMUeKbGRw,614
+ torchrir/room.py,sha256=zFnEzw0Rr1NP9IUc3iNTInyoq6t3X-0yOyUtDnsLSPk,4325
+ torchrir/scene.py,sha256=GuHuCspakAUOT81_ArTqaZbmBX0ApoJuCKTaZ21wGis,2435
+ torchrir/scene_utils.py,sha256=2La5dtjxYdINX315VXRRJMJK9oaR2rY0xHmDLjZma8M,2140
+ torchrir/signal.py,sha256=M0BpKDBqrfOmCHIJ_dvl-C3uKdFpXLDqtSIU115jsME,8383
+ torchrir/simulators.py,sha256=NCl8Ptv2TGdBpNLwAb3nigT77On-BLIANtc2ivgKasw,3131
+ torchrir/utils.py,sha256=2oE-JzAtkW5qdRds2Y5R5lbSyNZl_9piFXd6xOLzjxM,10680
+ torchrir/datasets/__init__.py,sha256=NS4zQas9YdsuDv8KQtTKmIJmS6mxMRxQk2xGzglbgUw,853
+ torchrir/datasets/base.py,sha256=LfdXO-NGCBtzaAqAeVxo8XuV5ieU6Vl91woqAHymsT8,1970
+ torchrir/datasets/cmu_arctic.py,sha256=7IFv33RBBu044kTMO6nKUmziml2gjILUgnpL262rAU8,6593
+ torchrir/datasets/collate.py,sha256=gZfaHog0gtb8Avg6qsDZ1m4yoKkYkcuwmty1RtLYhhI,2542
+ torchrir/datasets/librispeech.py,sha256=XKlAm0Z0coipKuqR9Z8X8l9puXVYz7zb6yE3PCuMUrI,6019
+ torchrir/datasets/template.py,sha256=pHAKj5E7Gehfk9pqdTsFQjiDV1OK3hSZJIbYutd-E4c,2090
+ torchrir/datasets/utils.py,sha256=OCYd7Dbr2hsqBbiHE1LHPMYdqwe2YfDw0tpRMfND0Og,3790
+ torchrir-0.2.0.dist-info/licenses/LICENSE,sha256=5vS_7WTsMEw_QQHEPQ_WCwovJXEgmxoEwcwOI-9VbXI,10766
+ torchrir-0.2.0.dist-info/licenses/NOTICE,sha256=SRs_q-ZqoVF9_YuuedZOvVBk01jV7YQAeF8rRvlRg0s,118
+ torchrir-0.2.0.dist-info/METADATA,sha256=0f5EhK2SWmP_RxZ2O8y7AIDP_dFaKaf9X_n0mt7045k,3077
+ torchrir-0.2.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+ torchrir-0.2.0.dist-info/top_level.txt,sha256=aIFwntowJjvm7rZk480HymC3ipDo1g-9hEbNY1wF-Oo,9
+ torchrir-0.2.0.dist-info/RECORD,,