torchrir 0.1.4__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
torchrir/__init__.py CHANGED
@@ -18,6 +18,11 @@ from .datasets import (
     CmuArcticDataset,
     CmuArcticSentence,
     choose_speakers,
+    CollateBatch,
+    collate_dataset_items,
+    DatasetItem,
+    LibriSpeechDataset,
+    LibriSpeechSentence,
     list_cmu_arctic_speakers,
     SentenceLike,
     load_dataset_sources,
@@ -61,6 +66,11 @@ __all__ = [
     "CmuArcticDataset",
     "CmuArcticSentence",
     "choose_speakers",
+    "CollateBatch",
+    "collate_dataset_items",
+    "DatasetItem",
+    "LibriSpeechDataset",
+    "LibriSpeechSentence",
     "DynamicConvolver",
     "estimate_beta_from_t60",
     "estimate_t60_from_beta",
torchrir/core.py CHANGED
@@ -262,6 +262,11 @@ def simulate_dynamic_rir(

     src_traj = as_tensor(src_traj, device=device, dtype=dtype)
     mic_traj = as_tensor(mic_traj, device=device, dtype=dtype)
+    device, dtype = infer_device_dtype(
+        src_traj, mic_traj, room.size, device=device, dtype=dtype
+    )
+    src_traj = as_tensor(src_traj, device=device, dtype=dtype)
+    mic_traj = as_tensor(mic_traj, device=device, dtype=dtype)

     if src_traj.ndim == 2:
         src_traj = src_traj.unsqueeze(1)
@@ -274,24 +279,95 @@ def simulate_dynamic_rir(
     if src_traj.shape[0] != mic_traj.shape[0]:
         raise ValueError("src_traj and mic_traj must have the same time length")

-    t_steps = src_traj.shape[0]
-    rirs = []
-    for t_idx in range(t_steps):
-        rir = simulate_rir(
-            room=room,
-            sources=src_traj[t_idx],
-            mics=mic_traj[t_idx],
-            max_order=max_order,
-            nsample=nsample,
-            tmax=tmax,
-            directivity=directivity,
-            orientation=orientation,
-            config=config,
-            device=device,
-            dtype=dtype,
+    if not isinstance(room, Room):
+        raise TypeError("room must be a Room instance")
+    if nsample is None:
+        if tmax is None:
+            raise ValueError("nsample or tmax must be provided")
+        nsample = int(math.ceil(tmax * room.fs))
+    if nsample <= 0:
+        raise ValueError("nsample must be positive")
+    if max_order < 0:
+        raise ValueError("max_order must be non-negative")
+
+    room_size = as_tensor(room.size, device=device, dtype=dtype)
+    room_size = ensure_dim(room_size)
+    dim = room_size.numel()
+    if src_traj.shape[2] != dim:
+        raise ValueError("src_traj must match room dimension")
+    if mic_traj.shape[2] != dim:
+        raise ValueError("mic_traj must match room dimension")
+
+    src_ori = None
+    mic_ori = None
+    if orientation is not None:
+        if isinstance(orientation, (list, tuple)):
+            if len(orientation) != 2:
+                raise ValueError("orientation tuple must have length 2")
+            src_ori, mic_ori = orientation
+        else:
+            src_ori = orientation
+            mic_ori = orientation
+    if src_ori is not None:
+        src_ori = as_tensor(src_ori, device=device, dtype=dtype)
+    if mic_ori is not None:
+        mic_ori = as_tensor(mic_ori, device=device, dtype=dtype)
+
+    beta = _resolve_beta(room, room_size, device=device, dtype=dtype)
+    beta = _validate_beta(beta, dim)
+    n_vec = _image_source_indices(max_order, dim, device=device, nb_img=None)
+    refl = _reflection_coefficients(n_vec, beta)
+
+    src_pattern, mic_pattern = split_directivity(directivity)
+    mic_dir = None
+    if mic_pattern != "omni":
+        if mic_ori is None:
+            raise ValueError("mic orientation required for non-omni directivity")
+        mic_dir = orientation_to_unit(mic_ori, dim)
+
+    n_src = src_traj.shape[1]
+    n_mic = mic_traj.shape[1]
+    rirs = torch.zeros((src_traj.shape[0], n_src, n_mic, nsample), device=device, dtype=dtype)
+    fdl = cfg.frac_delay_length
+    fdl2 = (fdl - 1) // 2
+    img_chunk = cfg.image_chunk_size
+    if img_chunk <= 0:
+        img_chunk = n_vec.shape[0]
+
+    src_dirs = None
+    if src_pattern != "omni":
+        if src_ori is None:
+            raise ValueError("source orientation required for non-omni directivity")
+        src_dirs = orientation_to_unit(src_ori, dim)
+        if src_dirs.ndim == 1:
+            src_dirs = src_dirs.unsqueeze(0).repeat(n_src, 1)
+        if src_dirs.ndim != 2 or src_dirs.shape[0] != n_src:
+            raise ValueError("source orientation must match number of sources")
+
+    for start in range(0, n_vec.shape[0], img_chunk):
+        end = min(start + img_chunk, n_vec.shape[0])
+        n_vec_chunk = n_vec[start:end]
+        refl_chunk = refl[start:end]
+        sample_chunk, attenuation_chunk = _compute_image_contributions_time_batch(
+            src_traj,
+            mic_traj,
+            room_size,
+            n_vec_chunk,
+            refl_chunk,
+            room,
+            fdl2,
+            src_pattern=src_pattern,
+            mic_pattern=mic_pattern,
+            src_dirs=src_dirs,
+            mic_dir=mic_dir,
         )
-        rirs.append(rir)
-    return torch.stack(rirs, dim=0)
+        t_steps = src_traj.shape[0]
+        sample_flat = sample_chunk.reshape(t_steps * n_src, n_mic, -1)
+        attenuation_flat = attenuation_chunk.reshape(t_steps * n_src, n_mic, -1)
+        rir_flat = rirs.view(t_steps * n_src, n_mic, nsample)
+        _accumulate_rir_batch(rir_flat, sample_flat, attenuation_flat, cfg)
+
+    return rirs


 def _prepare_entities(
@@ -508,6 +584,54 @@ def _compute_image_contributions_batch(
     return sample, attenuation


+def _compute_image_contributions_time_batch(
+    src_traj: Tensor,
+    mic_traj: Tensor,
+    room_size: Tensor,
+    n_vec: Tensor,
+    refl: Tensor,
+    room: Room,
+    fdl2: int,
+    *,
+    src_pattern: str,
+    mic_pattern: str,
+    src_dirs: Optional[Tensor],
+    mic_dir: Optional[Tensor],
+) -> Tuple[Tensor, Tensor]:
+    """Compute samples/attenuation for all time steps in batch."""
+    sign = torch.where((n_vec % 2) == 0, 1.0, -1.0).to(dtype=src_traj.dtype)
+    n = torch.floor_divide(n_vec + 1, 2).to(dtype=src_traj.dtype)
+    base = 2.0 * room_size * n
+    img = base[None, None, :, :] + sign[None, None, :, :] * src_traj[:, :, None, :]
+    vec = mic_traj[:, None, :, None, :] - img[:, :, None, :, :]
+    dist = torch.linalg.norm(vec, dim=-1)
+    dist = torch.clamp(dist, min=1e-6)
+    time = dist / room.c
+    time = time + (fdl2 / room.fs)
+    sample = time * room.fs
+
+    gain = refl.view(1, 1, 1, -1)
+    if src_pattern != "omni":
+        if src_dirs is None:
+            raise ValueError("source orientation required for non-omni directivity")
+        src_dirs_b = src_dirs[None, :, None, None, :]
+        cos_theta = _cos_between(vec, src_dirs_b)
+        gain = gain * directivity_gain(src_pattern, cos_theta)
+    if mic_pattern != "omni":
+        if mic_dir is None:
+            raise ValueError("mic orientation required for non-omni directivity")
+        mic_dir_b = (
+            mic_dir[None, None, :, None, :]
+            if mic_dir.ndim == 2
+            else mic_dir.view(1, 1, 1, 1, -1)
+        )
+        cos_theta = _cos_between(-vec, mic_dir_b)
+        gain = gain * directivity_gain(mic_pattern, cos_theta)
+
+    attenuation = gain / dist
+    return sample, attenuation
+
+
 def _select_orientation(orientation: Tensor, idx: int, count: int, dim: int) -> Tensor:
     """Pick the correct orientation vector for a given entity index."""
     if orientation.ndim == 0:
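
The hunks above replace the per-time-step loop over `simulate_rir` with a single batched pass: image-source contributions for every time step are computed by `_compute_image_contributions_time_batch` and accumulated into one `(t_steps, n_src, n_mic, nsample)` tensor. A minimal caller sketch follows; the `Room` constructor arguments and the exact import path are not shown in this diff and are assumptions here.

```python
import torch
from torchrir.core import Room, simulate_dynamic_rir  # import path assumed

# One source moving along x past two static microphones, 5 time steps, 3-D room.
t_steps, n_src, n_mic, dim = 5, 1, 2, 3
src_traj = torch.zeros(t_steps, n_src, dim)
src_traj[:, 0, 0] = torch.linspace(1.0, 3.0, t_steps)   # x sweeps 1 m -> 3 m
src_traj[:, 0, 1:] = torch.tensor([2.0, 1.5])           # fixed y/z
mic_traj = torch.tensor([[2.0, 3.0, 1.5],
                         [4.0, 3.0, 1.5]]).expand(t_steps, n_mic, dim)

room = Room(size=(6.0, 5.0, 3.0), fs=16000, beta=0.9)  # hypothetical constructor arguments
rirs = simulate_dynamic_rir(room=room, src_traj=src_traj, mic_traj=mic_traj,
                            max_order=6, nsample=4096)
print(rirs.shape)  # expected (5, 1, 2, 4096), per the batched return value above
```
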
torchrir/datasets/__init__.py CHANGED
@@ -1,14 +1,15 @@
 """Dataset helpers for torchrir."""

-from .base import BaseDataset, SentenceLike
-from .utils import choose_speakers, load_dataset_sources
+from .base import BaseDataset, DatasetItem, SentenceLike
+from .utils import choose_speakers, load_dataset_sources, load_wav_mono
+from .collate import CollateBatch, collate_dataset_items
 from .template import TemplateDataset, TemplateSentence
+from .librispeech import LibriSpeechDataset, LibriSpeechSentence

 from .cmu_arctic import (
     CmuArcticDataset,
     CmuArcticSentence,
     list_cmu_arctic_speakers,
-    load_wav_mono,
     save_wav,
 )

@@ -17,6 +18,9 @@ __all__ = [
     "CmuArcticDataset",
     "CmuArcticSentence",
     "choose_speakers",
+    "DatasetItem",
+    "CollateBatch",
+    "collate_dataset_items",
     "list_cmu_arctic_speakers",
     "SentenceLike",
     "load_dataset_sources",
@@ -24,4 +28,6 @@ __all__ = [
     "save_wav",
     "TemplateDataset",
     "TemplateSentence",
+    "LibriSpeechDataset",
+    "LibriSpeechSentence",
 ]
torchrir/datasets/base.py CHANGED
@@ -2,9 +2,11 @@ from __future__ import annotations

 """Dataset protocol definitions."""

-from typing import Protocol, Sequence, Tuple
+from dataclasses import dataclass
+from typing import Optional, Protocol, Sequence, Tuple

 import torch
+from torch.utils.data import Dataset


 class SentenceLike(Protocol):
@@ -14,14 +16,52 @@ class SentenceLike(Protocol):
     text: str


-class BaseDataset(Protocol):
-    """Protocol for datasets used in torchrir examples and tools."""
+@dataclass(frozen=True)
+class DatasetItem:
+    """Dataset item for DataLoader consumption."""
+
+    audio: torch.Tensor
+    sample_rate: int
+    utterance_id: str
+    text: Optional[str] = None
+    speaker: Optional[str] = None
+
+
+class BaseDataset(Dataset[DatasetItem]):
+    """Base dataset class compatible with torch.utils.data.Dataset."""
+
+    _sentences_cache: Optional[list[SentenceLike]] = None

     def list_speakers(self) -> list[str]:
         """Return available speaker IDs."""
+        raise NotImplementedError

     def available_sentences(self) -> Sequence[SentenceLike]:
         """Return sentence entries that have audio available."""
+        raise NotImplementedError

     def load_wav(self, utterance_id: str) -> Tuple[torch.Tensor, int]:
         """Load audio for an utterance and return (audio, sample_rate)."""
+        raise NotImplementedError
+
+    def __len__(self) -> int:
+        return len(self._get_sentences())
+
+    def __getitem__(self, idx: int) -> DatasetItem:
+        sentences = self._get_sentences()
+        sentence = sentences[idx]
+        audio, sample_rate = self.load_wav(sentence.utterance_id)
+        speaker = getattr(self, "speaker", None)
+        text = getattr(sentence, "text", None)
+        return DatasetItem(
+            audio=audio,
+            sample_rate=sample_rate,
+            utterance_id=sentence.utterance_id,
+            text=text,
+            speaker=speaker,
+        )
+
+    def _get_sentences(self) -> list[SentenceLike]:
+        if self._sentences_cache is None:
+            self._sentences_cache = list(self.available_sentences())
+        return self._sentences_cache
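
Because `BaseDataset` is now a concrete `torch.utils.data.Dataset` rather than a `Protocol`, subclasses only supply `list_speakers`, `available_sentences`, and `load_wav`; indexing and length come from the base class via the cached sentence list. A minimal sketch with a toy subclass; the synthetic sentences and audio are illustrative only.

```python
from types import SimpleNamespace

import torch

from torchrir.datasets import BaseDataset, DatasetItem


class ToyDataset(BaseDataset):
    """Two synthetic utterances; the minimum a subclass must provide."""

    speaker = "toy"  # picked up by BaseDataset.__getitem__ via getattr(self, "speaker", None)

    def list_speakers(self) -> list[str]:
        return [self.speaker]

    def available_sentences(self):
        # Any object with .utterance_id (and optionally .text) satisfies SentenceLike.
        return [SimpleNamespace(utterance_id=f"utt{i}", text=f"sentence {i}") for i in range(2)]

    def load_wav(self, utterance_id: str):
        return torch.zeros(16000), 16000  # 1 s of silence at 16 kHz


ds = ToyDataset()
item: DatasetItem = ds[0]
print(len(ds), item.utterance_id, item.sample_rate, item.speaker)  # 2 utt0 16000 toy
```
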
torchrir/datasets/cmu_arctic.py CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations

 """CMU ARCTIC dataset helpers."""

+import logging
 import tarfile
 import urllib.request
 from dataclasses import dataclass
@@ -9,7 +10,9 @@ from pathlib import Path
 from typing import List, Tuple

 import torch
-import logging
+
+from .base import BaseDataset
+from .utils import load_wav_mono

 BASE_URL = "http://www.festvox.org/cmu_arctic/packed"
 VALID_SPEAKERS = {
@@ -49,7 +52,7 @@ class CmuArcticSentence:
     text: str


-class CmuArcticDataset:
+class CmuArcticDataset(BaseDataset):
     """CMU ARCTIC dataset loader.

     Example:
@@ -191,23 +194,6 @@ def _parse_text_line(line: str) -> Tuple[str, str]:
     return utterance, text


-def load_wav_mono(path: Path) -> Tuple[torch.Tensor, int]:
-    """Load a wav file and return mono audio and sample rate.
-
-    Example:
-        >>> audio, fs = load_wav_mono(Path("datasets/cmu_arctic/ARCTIC/.../wav/arctic_a0001.wav"))
-    """
-    import soundfile as sf
-
-    audio, sample_rate = sf.read(str(path), dtype="float32", always_2d=True)
-    audio_t = torch.from_numpy(audio)
-    if audio_t.shape[1] > 1:
-        audio_t = audio_t.mean(dim=1)
-    else:
-        audio_t = audio_t.squeeze(1)
-    return audio_t, sample_rate
-
-
 def save_wav(path: Path, audio: torch.Tensor, sample_rate: int) -> None:
     """Save a mono or multi-channel wav to disk.

torchrir/datasets/collate.py ADDED
@@ -0,0 +1,90 @@
+from __future__ import annotations
+
+"""Collate helpers for DataLoader usage."""
+
+from dataclasses import dataclass
+from typing import Any, Iterable, List, Optional
+
+import torch
+from torch import Tensor
+
+from .base import DatasetItem
+
+
+@dataclass(frozen=True)
+class CollateBatch:
+    """Collated batch of dataset items.
+
+    Attributes:
+        audio: Padded audio tensor of shape (batch, max_len).
+        lengths: Original lengths for each item.
+        sample_rate: Sample rate shared across the batch.
+        utterance_ids: Utterance IDs per item.
+        texts: Optional text per item.
+        speakers: Optional speaker IDs per item.
+        metadata: Optional per-item metadata (pass-through).
+    """
+
+    audio: Tensor
+    lengths: Tensor
+    sample_rate: int
+    utterance_ids: list[str]
+    texts: list[Optional[str]]
+    speakers: list[Optional[str]]
+    metadata: Optional[list[Any]] = None
+
+
+def collate_dataset_items(
+    items: Iterable[DatasetItem],
+    *,
+    pad_value: float = 0.0,
+    keep_metadata: bool = False,
+) -> CollateBatch:
+    """Collate DatasetItem entries into a padded batch.
+
+    Args:
+        items: Iterable of DatasetItem.
+        pad_value: Value used for padding.
+        keep_metadata: Preserve item-level metadata field if present.
+
+    Returns:
+        CollateBatch with padded audio and metadata lists.
+    """
+    batch = list(items)
+    if not batch:
+        raise ValueError("collate_dataset_items received an empty batch")
+
+    sample_rate = batch[0].sample_rate
+    for item in batch[1:]:
+        if item.sample_rate != sample_rate:
+            raise ValueError("sample_rate must be consistent within a batch")
+
+    lengths = torch.tensor([item.audio.numel() for item in batch], dtype=torch.long)
+    max_len = int(lengths.max().item())
+    audio = torch.full(
+        (len(batch), max_len),
+        pad_value,
+        dtype=batch[0].audio.dtype,
+        device=batch[0].audio.device,
+    )
+
+    for idx, item in enumerate(batch):
+        audio[idx, : item.audio.numel()] = item.audio
+
+    utterance_ids = [item.utterance_id for item in batch]
+    texts = [item.text for item in batch]
+    speakers = [item.speaker for item in batch]
+
+    metadata: Optional[list[Any]] = None
+    if keep_metadata:
+        metadata = [getattr(item, "metadata", None) for item in batch]
+
+    return CollateBatch(
+        audio=audio,
+        lengths=lengths,
+        sample_rate=sample_rate,
+        utterance_ids=utterance_ids,
+        texts=texts,
+        speakers=speakers,
+        metadata=metadata,
+    )
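
`collate_dataset_items` pads variable-length audio to the longest item in the batch and enforces a single sample rate, so it can be passed directly as a `DataLoader` `collate_fn`. A short sketch using hand-built items; the lengths and IDs are illustrative.

```python
import torch

from torchrir.datasets import DatasetItem, collate_dataset_items

# Three items of different lengths but the same sample rate.
items = [
    DatasetItem(audio=torch.randn(n), sample_rate=16000, utterance_id=f"utt_{n}")
    for n in (12000, 16000, 8000)
]

batch = collate_dataset_items(items, pad_value=0.0)
print(batch.audio.shape)       # torch.Size([3, 16000]) -- zero-padded to the longest item
print(batch.lengths.tolist())  # [12000, 16000, 8000]
print(batch.sample_rate, batch.utterance_ids)

# With a DataLoader (any BaseDataset subclass works as `dataset`):
# loader = DataLoader(dataset, batch_size=4, collate_fn=collate_dataset_items)
```
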
torchrir/datasets/librispeech.py ADDED
@@ -0,0 +1,175 @@
+from __future__ import annotations
+
+"""LibriSpeech dataset helpers."""
+
+import logging
+import tarfile
+import urllib.request
+from dataclasses import dataclass
+from pathlib import Path
+from typing import List, Tuple
+
+import torch
+
+from .base import BaseDataset
+from .utils import load_wav_mono
+
+BASE_URL = "https://www.openslr.org/resources/12"
+VALID_SUBSETS = {
+    "dev-clean",
+    "dev-other",
+    "test-clean",
+    "test-other",
+    "train-clean-100",
+    "train-clean-360",
+    "train-other-500",
+}
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class LibriSpeechSentence:
+    """Sentence metadata from LibriSpeech."""
+
+    utterance_id: str
+    text: str
+    speaker_id: str
+    chapter_id: str
+
+
+class LibriSpeechDataset(BaseDataset):
+    """LibriSpeech dataset loader.
+
+    Example:
+        >>> dataset = LibriSpeechDataset(Path("datasets/librispeech"), subset="train-clean-100", download=True)
+        >>> audio, fs = dataset.load_wav("103-1240-0000")
+    """
+
+    def __init__(
+        self, root: Path, subset: str = "train-clean-100", download: bool = False
+    ) -> None:
+        """Initialize a LibriSpeech dataset handle.
+
+        Args:
+            root: Root directory where the dataset is stored.
+            subset: LibriSpeech subset name (e.g., "train-clean-100").
+            download: Download and extract if missing.
+        """
+        if subset not in VALID_SUBSETS:
+            raise ValueError(f"unsupported subset: {subset}")
+        self.root = Path(root)
+        self.subset = subset
+        self._archive_name = f"{subset}.tar.gz"
+        self._base_dir = self.root / "LibriSpeech"
+        self._subset_dir = self._base_dir / subset
+
+        if download:
+            self._download_and_extract()
+
+        if not self._subset_dir.exists():
+            raise FileNotFoundError(
+                "dataset not found; run with download=True or place the archive under "
+                f"{self.root}"
+            )
+
+    def list_speakers(self) -> List[str]:
+        """Return available speaker IDs."""
+        if not self._subset_dir.exists():
+            return []
+        return sorted([p.name for p in self._subset_dir.iterdir() if p.is_dir()])
+
+    def available_sentences(self) -> List[LibriSpeechSentence]:
+        """Return sentences that have a corresponding audio file."""
+        sentences: List[LibriSpeechSentence] = []
+        for trans_path in self._subset_dir.rglob("*.trans.txt"):
+            chapter_dir = trans_path.parent
+            speaker_id = chapter_dir.parent.name
+            chapter_id = chapter_dir.name
+            with trans_path.open("r", encoding="utf-8") as f:
+                for line in f:
+                    line = line.strip()
+                    if not line:
+                        continue
+                    utt_id, text = _parse_text_line(line)
+                    wav_path = chapter_dir / f"{utt_id}.flac"
+                    if wav_path.exists():
+                        sentences.append(
+                            LibriSpeechSentence(
+                                utterance_id=utt_id,
+                                text=text,
+                                speaker_id=speaker_id,
+                                chapter_id=chapter_id,
+                            )
+                        )
+        return sentences
+
+    def load_wav(self, utterance_id: str) -> Tuple[torch.Tensor, int]:
+        """Load a mono wav for the given utterance ID."""
+        speaker_id, chapter_id, _ = utterance_id.split("-", 2)
+        path = self._subset_dir / speaker_id / chapter_id / f"{utterance_id}.flac"
+        return load_wav_mono(path)
+
+    def _download_and_extract(self) -> None:
+        """Download and extract the subset archive if needed."""
+        self.root.mkdir(parents=True, exist_ok=True)
+        archive_path = self.root / self._archive_name
+        url = f"{BASE_URL}/{self._archive_name}"
+
+        if not archive_path.exists():
+            logger.info("Downloading %s", url)
+            _download(url, archive_path)
+        if not self._subset_dir.exists():
+            logger.info("Extracting %s", archive_path)
+            try:
+                with tarfile.open(archive_path, "r:gz") as tar:
+                    tar.extractall(self.root)
+            except (tarfile.ReadError, EOFError, OSError) as exc:
+                logger.warning("Extraction failed (%s); re-downloading.", exc)
+                if archive_path.exists():
+                    archive_path.unlink()
+                _download(url, archive_path)
+                with tarfile.open(archive_path, "r:gz") as tar:
+                    tar.extractall(self.root)
+
+
+def _download(url: str, dest: Path, retries: int = 1) -> None:
+    """Download a file with retry and resume-safe temp file."""
+    for attempt in range(retries + 1):
+        try:
+            _stream_download(url, dest)
+            return
+        except Exception as exc:
+            if dest.exists():
+                dest.unlink()
+            if attempt >= retries:
+                raise
+            logger.warning("Download failed (%s); retrying...", exc)
+
+
+def _stream_download(url: str, dest: Path) -> None:
+    """Stream a URL to disk with a progress indicator."""
+    tmp_path = dest.with_suffix(dest.suffix + ".part")
+    if tmp_path.exists():
+        tmp_path.unlink()
+
+    with urllib.request.urlopen(url) as response:
+        total = response.length or 0
+        downloaded = 0
+        chunk_size = 1024 * 1024
+        with tmp_path.open("wb") as f:
+            while True:
+                chunk = response.read(chunk_size)
+                if not chunk:
+                    break
+                f.write(chunk)
+                downloaded += len(chunk)
+    if total > 0 and downloaded != total:
+        raise IOError(f"incomplete download: {downloaded} of {total} bytes")
+    tmp_path.replace(dest)
+
+
+def _parse_text_line(line: str) -> Tuple[str, str]:
+    """Parse a LibriSpeech transcript line into (utterance_id, text)."""
+    left, _, right = line.partition(" ")
+    return left, right.strip()
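
The LibriSpeech loader follows the CMU ARCTIC pattern: utterance IDs are `speaker-chapter-index` strings, audio lives in per-chapter FLAC files, and `download=True` fetches and extracts the subset archive from openslr.org. A usage sketch, assuming the `dev-clean` subset is available locally or can be downloaded; the utterance ID shown is only an example.

```python
from pathlib import Path

from torchrir.datasets import LibriSpeechDataset

dataset = LibriSpeechDataset(Path("datasets/librispeech"), subset="dev-clean", download=True)
print(dataset.list_speakers()[:5])

sentences = dataset.available_sentences()
first = sentences[0]
audio, fs = dataset.load_wav(first.utterance_id)  # e.g. "84-121123-0000"
print(first.speaker_id, first.chapter_id, audio.shape, fs)

# LibriSpeechDataset subclasses BaseDataset, so indexing yields DatasetItem as well:
item = dataset[0]  # item.audio, item.sample_rate, item.utterance_id, item.text
```
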
torchrir/datasets/utils.py CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
 """Dataset-agnostic utilities."""

 import random
+from pathlib import Path
 from typing import Callable, List, Optional, Sequence, Tuple

 import torch
@@ -94,3 +95,20 @@ def load_dataset_sources(
     if fs is None:
         raise RuntimeError("no audio loaded from dataset sources")
     return stacked, int(fs), info
+
+
+def load_wav_mono(path: Path) -> Tuple[torch.Tensor, int]:
+    """Load a wav/flac file and return mono audio and sample rate.
+
+    Example:
+        >>> audio, fs = load_wav_mono(Path("datasets/cmu_arctic/ARCTIC/.../wav/arctic_a0001.wav"))
+    """
+    import soundfile as sf
+
+    audio, sample_rate = sf.read(str(path), dtype="float32", always_2d=True)
+    audio_t = torch.from_numpy(audio)
+    if audio_t.shape[1] > 1:
+        audio_t = audio_t.mean(dim=1)
+    else:
+        audio_t = audio_t.squeeze(1)
+    return audio_t, sample_rate
torchrir-0.2.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: torchrir
-Version: 0.1.4
+Version: 0.2.0
 Summary: PyTorch-based room impulse response (RIR) simulation toolkit for static and dynamic scenes.
 Project-URL: Repository, https://github.com/taishi-n/torchrir
 Requires-Python: >=3.10
@@ -59,7 +59,7 @@ For detailed documentation, see the docs under `docs/` and Read the Docs.
 ## Future Work
 - Ray tracing backend: implement `RayTracingSimulator` with frequency-dependent absorption/scattering.
 - CUDA-native acceleration: introduce dedicated CUDA kernels for large-scale RIR generation.
-- Dataset expansion: add additional dataset integrations beyond CMU ARCTIC (see `TemplateDataset`).
+- Dataset expansion: add additional dataset integrations beyond CMU ARCTIC (see `TemplateDataset`), including torchaudio datasets (e.g., LibriSpeech, VCTK, LibriTTS, SpeechCommands, CommonVoice, GTZAN, MUSDB-HQ).
 - Add regression tests comparing generated RIRs against gpuRIR outputs.

 ## Related Libraries
torchrir-0.2.0.dist-info/RECORD CHANGED
@@ -1,7 +1,7 @@
-torchrir/__init__.py,sha256=urydbUWuUHPBqmy-9QBaQg8eFGznRamkSjLmPiNvBo0,2383
+torchrir/__init__.py,sha256=MTlouAErvB7IyM4pcmhEN1U0KQsZBglWxUoHxnZDa5U,2615
 torchrir/animation.py,sha256=x3Y-BLz3J6DQNmoDIjbMEgGfng2yavJFLyQEmRCSpQU,6391
 torchrir/config.py,sha256=PsZdDIS3p4jepeNSHyd69aSD9QlOEdpG9v1SAXlZ_Fg,2295
-torchrir/core.py,sha256=VdljYoCoQoZqD8aYJRnuHEb7uORQjyQysVc8K3RGuao,26826
+torchrir/core.py,sha256=Ug5thts1rXvCpdq9twVHz72oWygmk5J6gliloozHKL4,31704
 torchrir/directivity.py,sha256=v_t37YgeXF_IYzbnrk0TCs1npb_0yKR7zHiG8XV3V4w,1259
 torchrir/dynamic.py,sha256=01JHMxhORdcz93J-YaMIeSLo7k2tHrZke8llPHHXwZg,2153
 torchrir/logging_utils.py,sha256=s4jDSSDoHT0HKeplDUpGMsdeBij4eibLSpaaAPzkB68,2146
@@ -15,14 +15,16 @@ torchrir/scene_utils.py,sha256=2La5dtjxYdINX315VXRRJMJK9oaR2rY0xHmDLjZma8M,2140
 torchrir/signal.py,sha256=M0BpKDBqrfOmCHIJ_dvl-C3uKdFpXLDqtSIU115jsME,8383
 torchrir/simulators.py,sha256=NCl8Ptv2TGdBpNLwAb3nigT77On-BLIANtc2ivgKasw,3131
 torchrir/utils.py,sha256=2oE-JzAtkW5qdRds2Y5R5lbSyNZl_9piFXd6xOLzjxM,10680
-torchrir/datasets/__init__.py,sha256=3T55F3fjjRR3j618ubRkMlZnQTxvXaxioFMhygxm7oQ,601
-torchrir/datasets/base.py,sha256=mCHLtGOOaD1II1alJpP6ipzkz87l-rh19NgfeLnJbDU,720
-torchrir/datasets/cmu_arctic.py,sha256=DrOcawHvOEUnFJRw4qZgwuK1jbL2oQ-Vz_zNodYtpjE,7049
+torchrir/datasets/__init__.py,sha256=NS4zQas9YdsuDv8KQtTKmIJmS6mxMRxQk2xGzglbgUw,853
+torchrir/datasets/base.py,sha256=LfdXO-NGCBtzaAqAeVxo8XuV5ieU6Vl91woqAHymsT8,1970
+torchrir/datasets/cmu_arctic.py,sha256=7IFv33RBBu044kTMO6nKUmziml2gjILUgnpL262rAU8,6593
+torchrir/datasets/collate.py,sha256=gZfaHog0gtb8Avg6qsDZ1m4yoKkYkcuwmty1RtLYhhI,2542
+torchrir/datasets/librispeech.py,sha256=XKlAm0Z0coipKuqR9Z8X8l9puXVYz7zb6yE3PCuMUrI,6019
 torchrir/datasets/template.py,sha256=pHAKj5E7Gehfk9pqdTsFQjiDV1OK3hSZJIbYutd-E4c,2090
-torchrir/datasets/utils.py,sha256=TUfdt_XSB71ztCfzq_gCNrbvPh0Y-O5gkyxUnHWYID0,3227
-torchrir-0.1.4.dist-info/licenses/LICENSE,sha256=5vS_7WTsMEw_QQHEPQ_WCwovJXEgmxoEwcwOI-9VbXI,10766
-torchrir-0.1.4.dist-info/licenses/NOTICE,sha256=SRs_q-ZqoVF9_YuuedZOvVBk01jV7YQAeF8rRvlRg0s,118
-torchrir-0.1.4.dist-info/METADATA,sha256=HNSQV3uXeRYfX9eDb7ZllGAMvWCdtzq9Rn-q0kokkL4,2964
-torchrir-0.1.4.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
-torchrir-0.1.4.dist-info/top_level.txt,sha256=aIFwntowJjvm7rZk480HymC3ipDo1g-9hEbNY1wF-Oo,9
-torchrir-0.1.4.dist-info/RECORD,,
+torchrir/datasets/utils.py,sha256=OCYd7Dbr2hsqBbiHE1LHPMYdqwe2YfDw0tpRMfND0Og,3790
+torchrir-0.2.0.dist-info/licenses/LICENSE,sha256=5vS_7WTsMEw_QQHEPQ_WCwovJXEgmxoEwcwOI-9VbXI,10766
+torchrir-0.2.0.dist-info/licenses/NOTICE,sha256=SRs_q-ZqoVF9_YuuedZOvVBk01jV7YQAeF8rRvlRg0s,118
+torchrir-0.2.0.dist-info/METADATA,sha256=0f5EhK2SWmP_RxZ2O8y7AIDP_dFaKaf9X_n0mt7045k,3077
+torchrir-0.2.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+torchrir-0.2.0.dist-info/top_level.txt,sha256=aIFwntowJjvm7rZk480HymC3ipDo1g-9hEbNY1wF-Oo,9
+torchrir-0.2.0.dist-info/RECORD,,