torchrir-0.1.4-py3-none-any.whl → torchrir-0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchrir/__init__.py +10 -0
- torchrir/core.py +141 -17
- torchrir/datasets/__init__.py +9 -3
- torchrir/datasets/base.py +43 -3
- torchrir/datasets/cmu_arctic.py +5 -19
- torchrir/datasets/collate.py +90 -0
- torchrir/datasets/librispeech.py +175 -0
- torchrir/datasets/utils.py +18 -0
- {torchrir-0.1.4.dist-info → torchrir-0.2.0.dist-info}/METADATA +2 -2
- {torchrir-0.1.4.dist-info → torchrir-0.2.0.dist-info}/RECORD +14 -12
- {torchrir-0.1.4.dist-info → torchrir-0.2.0.dist-info}/WHEEL +0 -0
- {torchrir-0.1.4.dist-info → torchrir-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {torchrir-0.1.4.dist-info → torchrir-0.2.0.dist-info}/licenses/NOTICE +0 -0
- {torchrir-0.1.4.dist-info → torchrir-0.2.0.dist-info}/top_level.txt +0 -0
torchrir/__init__.py
CHANGED

@@ -18,6 +18,11 @@ from .datasets import (
     CmuArcticDataset,
     CmuArcticSentence,
     choose_speakers,
+    CollateBatch,
+    collate_dataset_items,
+    DatasetItem,
+    LibriSpeechDataset,
+    LibriSpeechSentence,
     list_cmu_arctic_speakers,
     SentenceLike,
     load_dataset_sources,

@@ -61,6 +66,11 @@ __all__ = [
     "CmuArcticDataset",
     "CmuArcticSentence",
     "choose_speakers",
+    "CollateBatch",
+    "collate_dataset_items",
+    "DatasetItem",
+    "LibriSpeechDataset",
+    "LibriSpeechSentence",
     "DynamicConvolver",
     "estimate_beta_from_t60",
     "estimate_t60_from_beta",
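The new dataset helpers are re-exported at the package root, so after the `__all__` additions shown above they can be imported directly from `torchrir`:

```python
# New public names in torchrir 0.2.0 (per the __all__ additions above).
from torchrir import (
    CollateBatch,
    DatasetItem,
    LibriSpeechDataset,
    LibriSpeechSentence,
    collate_dataset_items,
)
```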
torchrir/core.py
CHANGED

@@ -262,6 +262,11 @@ def simulate_dynamic_rir(

     src_traj = as_tensor(src_traj, device=device, dtype=dtype)
     mic_traj = as_tensor(mic_traj, device=device, dtype=dtype)
+    device, dtype = infer_device_dtype(
+        src_traj, mic_traj, room.size, device=device, dtype=dtype
+    )
+    src_traj = as_tensor(src_traj, device=device, dtype=dtype)
+    mic_traj = as_tensor(mic_traj, device=device, dtype=dtype)

     if src_traj.ndim == 2:
         src_traj = src_traj.unsqueeze(1)
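The new preamble converts the trajectories, infers a common device/dtype from the trajectories and the room size, then re-converts both trajectories so everything lives on the same device. `infer_device_dtype` itself is not shown in this diff; the sketch below only illustrates what such a helper typically does (an assumption, not torchrir's actual implementation):

```python
from functools import reduce
from typing import Optional, Tuple

import torch


def infer_device_dtype_sketch(
    *tensors: torch.Tensor,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.device, torch.dtype]:
    """Hypothetical helper: explicit device/dtype win; otherwise pick the first
    CUDA device among the inputs and the promoted dtype of all inputs."""
    if device is None:
        device = next((t.device for t in tensors if t.is_cuda), torch.device("cpu"))
    if dtype is None:
        dtype = reduce(torch.promote_types, (t.dtype for t in tensors))
    return device, dtype
```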
@@ -274,24 +279,95 @@ def simulate_dynamic_rir(
     if src_traj.shape[0] != mic_traj.shape[0]:
         raise ValueError("src_traj and mic_traj must have the same time length")

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    if not isinstance(room, Room):
+        raise TypeError("room must be a Room instance")
+    if nsample is None:
+        if tmax is None:
+            raise ValueError("nsample or tmax must be provided")
+        nsample = int(math.ceil(tmax * room.fs))
+    if nsample <= 0:
+        raise ValueError("nsample must be positive")
+    if max_order < 0:
+        raise ValueError("max_order must be non-negative")
+
+    room_size = as_tensor(room.size, device=device, dtype=dtype)
+    room_size = ensure_dim(room_size)
+    dim = room_size.numel()
+    if src_traj.shape[2] != dim:
+        raise ValueError("src_traj must match room dimension")
+    if mic_traj.shape[2] != dim:
+        raise ValueError("mic_traj must match room dimension")
+
+    src_ori = None
+    mic_ori = None
+    if orientation is not None:
+        if isinstance(orientation, (list, tuple)):
+            if len(orientation) != 2:
+                raise ValueError("orientation tuple must have length 2")
+            src_ori, mic_ori = orientation
+        else:
+            src_ori = orientation
+            mic_ori = orientation
+    if src_ori is not None:
+        src_ori = as_tensor(src_ori, device=device, dtype=dtype)
+    if mic_ori is not None:
+        mic_ori = as_tensor(mic_ori, device=device, dtype=dtype)
+
+    beta = _resolve_beta(room, room_size, device=device, dtype=dtype)
+    beta = _validate_beta(beta, dim)
+    n_vec = _image_source_indices(max_order, dim, device=device, nb_img=None)
+    refl = _reflection_coefficients(n_vec, beta)
+
+    src_pattern, mic_pattern = split_directivity(directivity)
+    mic_dir = None
+    if mic_pattern != "omni":
+        if mic_ori is None:
+            raise ValueError("mic orientation required for non-omni directivity")
+        mic_dir = orientation_to_unit(mic_ori, dim)
+
+    n_src = src_traj.shape[1]
+    n_mic = mic_traj.shape[1]
+    rirs = torch.zeros((src_traj.shape[0], n_src, n_mic, nsample), device=device, dtype=dtype)
+    fdl = cfg.frac_delay_length
+    fdl2 = (fdl - 1) // 2
+    img_chunk = cfg.image_chunk_size
+    if img_chunk <= 0:
+        img_chunk = n_vec.shape[0]
+
+    src_dirs = None
+    if src_pattern != "omni":
+        if src_ori is None:
+            raise ValueError("source orientation required for non-omni directivity")
+        src_dirs = orientation_to_unit(src_ori, dim)
+        if src_dirs.ndim == 1:
+            src_dirs = src_dirs.unsqueeze(0).repeat(n_src, 1)
+        if src_dirs.ndim != 2 or src_dirs.shape[0] != n_src:
+            raise ValueError("source orientation must match number of sources")
+
+    for start in range(0, n_vec.shape[0], img_chunk):
+        end = min(start + img_chunk, n_vec.shape[0])
+        n_vec_chunk = n_vec[start:end]
+        refl_chunk = refl[start:end]
+        sample_chunk, attenuation_chunk = _compute_image_contributions_time_batch(
+            src_traj,
+            mic_traj,
+            room_size,
+            n_vec_chunk,
+            refl_chunk,
+            room,
+            fdl2,
+            src_pattern=src_pattern,
+            mic_pattern=mic_pattern,
+            src_dirs=src_dirs,
+            mic_dir=mic_dir,
         )
-
-
+        t_steps = src_traj.shape[0]
+        sample_flat = sample_chunk.reshape(t_steps * n_src, n_mic, -1)
+        attenuation_flat = attenuation_chunk.reshape(t_steps * n_src, n_mic, -1)
+        rir_flat = rirs.view(t_steps * n_src, n_mic, nsample)
+        _accumulate_rir_batch(rir_flat, sample_flat, attenuation_flat, cfg)
+
+    return rirs


 def _prepare_entities(
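For context, a usage sketch of the rewritten function. The keyword names below (`src_traj`, `mic_traj`, `nsample`, `max_order`, `directivity`) are read off the body shown in this hunk; the full signature, its defaults, and the `Room` constructor are not part of this diff, so the call is illustrative rather than authoritative.

```python
import torch
import torchrir

# Hypothetical room setup: the Room constructor arguments are assumptions.
room = torchrir.Room(size=[6.0, 4.0, 3.0], fs=16000)

T = 200                                             # trajectory time steps
src_traj = torch.zeros(T, 1, 3)                     # (time, n_src, dim)
src_traj[:, 0, 0] = torch.linspace(1.0, 5.0, T)     # source moves along x
src_traj[:, 0, 1:] = torch.tensor([2.0, 1.5])
mic_traj = torch.tensor([3.0, 2.0, 1.2]).expand(T, 1, 3)   # static mic

rirs = torchrir.simulate_dynamic_rir(
    room=room,
    src_traj=src_traj,
    mic_traj=mic_traj,
    nsample=4096,
    max_order=10,
    directivity="omni",
)
# Per the allocation in this hunk, rirs has shape (time, n_src, n_mic, nsample).
print(rirs.shape)
```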
@@ -508,6 +584,54 @@ def _compute_image_contributions_batch(
     return sample, attenuation


+def _compute_image_contributions_time_batch(
+    src_traj: Tensor,
+    mic_traj: Tensor,
+    room_size: Tensor,
+    n_vec: Tensor,
+    refl: Tensor,
+    room: Room,
+    fdl2: int,
+    *,
+    src_pattern: str,
+    mic_pattern: str,
+    src_dirs: Optional[Tensor],
+    mic_dir: Optional[Tensor],
+) -> Tuple[Tensor, Tensor]:
+    """Compute samples/attenuation for all time steps in batch."""
+    sign = torch.where((n_vec % 2) == 0, 1.0, -1.0).to(dtype=src_traj.dtype)
+    n = torch.floor_divide(n_vec + 1, 2).to(dtype=src_traj.dtype)
+    base = 2.0 * room_size * n
+    img = base[None, None, :, :] + sign[None, None, :, :] * src_traj[:, :, None, :]
+    vec = mic_traj[:, None, :, None, :] - img[:, :, None, :, :]
+    dist = torch.linalg.norm(vec, dim=-1)
+    dist = torch.clamp(dist, min=1e-6)
+    time = dist / room.c
+    time = time + (fdl2 / room.fs)
+    sample = time * room.fs
+
+    gain = refl.view(1, 1, 1, -1)
+    if src_pattern != "omni":
+        if src_dirs is None:
+            raise ValueError("source orientation required for non-omni directivity")
+        src_dirs_b = src_dirs[None, :, None, None, :]
+        cos_theta = _cos_between(vec, src_dirs_b)
+        gain = gain * directivity_gain(src_pattern, cos_theta)
+    if mic_pattern != "omni":
+        if mic_dir is None:
+            raise ValueError("mic orientation required for non-omni directivity")
+        mic_dir_b = (
+            mic_dir[None, None, :, None, :]
+            if mic_dir.ndim == 2
+            else mic_dir.view(1, 1, 1, 1, -1)
+        )
+        cos_theta = _cos_between(-vec, mic_dir_b)
+        gain = gain * directivity_gain(mic_pattern, cos_theta)
+
+    attenuation = gain / dist
+    return sample, attenuation
+
+
 def _select_orientation(orientation: Tensor, idx: int, count: int, dim: int) -> Tensor:
     """Pick the correct orientation vector for a given entity index."""
     if orientation.ndim == 0:
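The new helper vectorizes the image-source geometry over every trajectory step at once instead of looping over time. A self-contained sketch of the broadcasting it relies on (same shape convention as the code above: `T` time steps, `n_src` sources, `n_mic` mics, `n_img` image sources, `dim` spatial dimensions):

```python
import torch

T, n_src, n_mic, n_img, dim = 5, 2, 3, 7, 3

src_traj = torch.rand(T, n_src, dim) * 4.0
mic_traj = torch.rand(T, n_mic, dim) * 4.0
room_size = torch.tensor([6.0, 4.0, 3.0])
n_vec = torch.randint(-2, 3, (n_img, dim))        # image-source index triples

# Even indices keep the source coordinate's sign, odd indices mirror it.
sign = torch.where((n_vec % 2) == 0, 1.0, -1.0)
n = torch.floor_divide(n_vec + 1, 2).to(torch.float32)
base = 2.0 * room_size * n                        # (n_img, dim)

img = base[None, None, :, :] + sign[None, None, :, :] * src_traj[:, :, None, :]
vec = mic_traj[:, None, :, None, :] - img[:, :, None, :, :]
dist = torch.linalg.norm(vec, dim=-1).clamp(min=1e-6)

print(img.shape)    # (T, n_src, n_img, dim)        -> torch.Size([5, 2, 7, 3])
print(vec.shape)    # (T, n_src, n_mic, n_img, dim) -> torch.Size([5, 2, 3, 7, 3])
print(dist.shape)   # (T, n_src, n_mic, n_img)      -> torch.Size([5, 2, 3, 7])
```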
torchrir/datasets/__init__.py
CHANGED

@@ -1,14 +1,15 @@
 """Dataset helpers for torchrir."""

-from .base import BaseDataset, SentenceLike
-from .utils import choose_speakers, load_dataset_sources
+from .base import BaseDataset, DatasetItem, SentenceLike
+from .utils import choose_speakers, load_dataset_sources, load_wav_mono
+from .collate import CollateBatch, collate_dataset_items
 from .template import TemplateDataset, TemplateSentence
+from .librispeech import LibriSpeechDataset, LibriSpeechSentence

 from .cmu_arctic import (
     CmuArcticDataset,
     CmuArcticSentence,
     list_cmu_arctic_speakers,
-    load_wav_mono,
     save_wav,
 )

@@ -17,6 +18,9 @@ __all__ = [
     "CmuArcticDataset",
     "CmuArcticSentence",
     "choose_speakers",
+    "DatasetItem",
+    "CollateBatch",
+    "collate_dataset_items",
     "list_cmu_arctic_speakers",
     "SentenceLike",
     "load_dataset_sources",

@@ -24,4 +28,6 @@ __all__ = [
     "save_wav",
     "TemplateDataset",
     "TemplateSentence",
+    "LibriSpeechDataset",
+    "LibriSpeechSentence",
 ]
torchrir/datasets/base.py
CHANGED

@@ -2,9 +2,11 @@ from __future__ import annotations

 """Dataset protocol definitions."""

-from
+from dataclasses import dataclass
+from typing import Optional, Protocol, Sequence, Tuple

 import torch
+from torch.utils.data import Dataset


 class SentenceLike(Protocol):

@@ -14,14 +16,52 @@ class SentenceLike(Protocol):
     text: str


-
-
+@dataclass(frozen=True)
+class DatasetItem:
+    """Dataset item for DataLoader consumption."""
+
+    audio: torch.Tensor
+    sample_rate: int
+    utterance_id: str
+    text: Optional[str] = None
+    speaker: Optional[str] = None
+
+
+class BaseDataset(Dataset[DatasetItem]):
+    """Base dataset class compatible with torch.utils.data.Dataset."""
+
+    _sentences_cache: Optional[list[SentenceLike]] = None

     def list_speakers(self) -> list[str]:
         """Return available speaker IDs."""
+        raise NotImplementedError

     def available_sentences(self) -> Sequence[SentenceLike]:
         """Return sentence entries that have audio available."""
+        raise NotImplementedError

     def load_wav(self, utterance_id: str) -> Tuple[torch.Tensor, int]:
         """Load audio for an utterance and return (audio, sample_rate)."""
+        raise NotImplementedError
+
+    def __len__(self) -> int:
+        return len(self._get_sentences())
+
+    def __getitem__(self, idx: int) -> DatasetItem:
+        sentences = self._get_sentences()
+        sentence = sentences[idx]
+        audio, sample_rate = self.load_wav(sentence.utterance_id)
+        speaker = getattr(self, "speaker", None)
+        text = getattr(sentence, "text", None)
+        return DatasetItem(
+            audio=audio,
+            sample_rate=sample_rate,
+            utterance_id=sentence.utterance_id,
+            text=text,
+            speaker=speaker,
+        )
+
+    def _get_sentences(self) -> list[SentenceLike]:
+        if self._sentences_cache is None:
+            self._sentences_cache = list(self.available_sentences())
+        return self._sentences_cache
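Concrete datasets now only need to implement the three methods above; `__len__` and `__getitem__` come for free. A minimal sketch of a custom subclass (the `ToySentence`/`ToyDataset` names are hypothetical, not part of torchrir):

```python
from dataclasses import dataclass
from typing import List, Tuple

import torch

from torchrir.datasets import BaseDataset, DatasetItem


@dataclass
class ToySentence:
    # Satisfies the SentenceLike protocol used by BaseDataset.__getitem__.
    utterance_id: str
    text: str


class ToyDataset(BaseDataset):
    """Two in-memory sine bursts standing in for real audio files."""

    speaker = "toy"  # picked up via getattr(self, "speaker", None) in __getitem__

    def __init__(self) -> None:
        self._fs = 16000
        self._audio = {
            "toy_0001": torch.sin(torch.linspace(0.0, 6.28, 8000)),
            "toy_0002": torch.sin(torch.linspace(0.0, 6.28, 12000)),
        }

    def list_speakers(self) -> List[str]:
        return [self.speaker]

    def available_sentences(self) -> List[ToySentence]:
        return [ToySentence(uid, f"utterance {uid}") for uid in sorted(self._audio)]

    def load_wav(self, utterance_id: str) -> Tuple[torch.Tensor, int]:
        return self._audio[utterance_id], self._fs


ds = ToyDataset()
item: DatasetItem = ds[0]
print(len(ds), item.utterance_id, item.audio.shape, item.sample_rate)
```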
torchrir/datasets/cmu_arctic.py
CHANGED

@@ -2,6 +2,7 @@ from __future__ import annotations

 """CMU ARCTIC dataset helpers."""

+import logging
 import tarfile
 import urllib.request
 from dataclasses import dataclass

@@ -9,7 +10,9 @@ from pathlib import Path
 from typing import List, Tuple

 import torch
-
+
+from .base import BaseDataset
+from .utils import load_wav_mono

 BASE_URL = "http://www.festvox.org/cmu_arctic/packed"
 VALID_SPEAKERS = {

@@ -49,7 +52,7 @@ class CmuArcticSentence:
     text: str


-class CmuArcticDataset:
+class CmuArcticDataset(BaseDataset):
     """CMU ARCTIC dataset loader.

     Example:

@@ -191,23 +194,6 @@ def _parse_text_line(line: str) -> Tuple[str, str]:
     return utterance, text


-def load_wav_mono(path: Path) -> Tuple[torch.Tensor, int]:
-    """Load a wav file and return mono audio and sample rate.
-
-    Example:
-        >>> audio, fs = load_wav_mono(Path("datasets/cmu_arctic/ARCTIC/.../wav/arctic_a0001.wav"))
-    """
-    import soundfile as sf
-
-    audio, sample_rate = sf.read(str(path), dtype="float32", always_2d=True)
-    audio_t = torch.from_numpy(audio)
-    if audio_t.shape[1] > 1:
-        audio_t = audio_t.mean(dim=1)
-    else:
-        audio_t = audio_t.squeeze(1)
-    return audio_t, sample_rate
-
-
 def save_wav(path: Path, audio: torch.Tensor, sample_rate: int) -> None:
     """Save a mono or multi-channel wav to disk.

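Note that `load_wav_mono` has not been dropped: it moved to `torchrir.datasets.utils` (see the utils.py hunk below) and is still re-exported from `torchrir.datasets`, so only imports that referenced the `cmu_arctic` module directly need updating:

```python
# 0.1.4:
# from torchrir.datasets.cmu_arctic import load_wav_mono
# 0.2.0 (either form works):
from torchrir.datasets import load_wav_mono            # package re-export
from torchrir.datasets.utils import load_wav_mono      # new home
```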
torchrir/datasets/collate.py
ADDED

@@ -0,0 +1,90 @@
+from __future__ import annotations
+
+"""Collate helpers for DataLoader usage."""
+
+from dataclasses import dataclass
+from typing import Any, Iterable, List, Optional
+
+import torch
+from torch import Tensor
+
+from .base import DatasetItem
+
+
+@dataclass(frozen=True)
+class CollateBatch:
+    """Collated batch of dataset items.
+
+    Attributes:
+        audio: Padded audio tensor of shape (batch, max_len).
+        lengths: Original lengths for each item.
+        sample_rate: Sample rate shared across the batch.
+        utterance_ids: Utterance IDs per item.
+        texts: Optional text per item.
+        speakers: Optional speaker IDs per item.
+        metadata: Optional per-item metadata (pass-through).
+    """
+
+    audio: Tensor
+    lengths: Tensor
+    sample_rate: int
+    utterance_ids: list[str]
+    texts: list[Optional[str]]
+    speakers: list[Optional[str]]
+    metadata: Optional[list[Any]] = None
+
+
+def collate_dataset_items(
+    items: Iterable[DatasetItem],
+    *,
+    pad_value: float = 0.0,
+    keep_metadata: bool = False,
+) -> CollateBatch:
+    """Collate DatasetItem entries into a padded batch.
+
+    Args:
+        items: Iterable of DatasetItem.
+        pad_value: Value used for padding.
+        keep_metadata: Preserve item-level metadata field if present.
+
+    Returns:
+        CollateBatch with padded audio and metadata lists.
+    """
+    batch = list(items)
+    if not batch:
+        raise ValueError("collate_dataset_items received an empty batch")
+
+    sample_rate = batch[0].sample_rate
+    for item in batch[1:]:
+        if item.sample_rate != sample_rate:
+            raise ValueError("sample_rate must be consistent within a batch")
+
+    lengths = torch.tensor([item.audio.numel() for item in batch], dtype=torch.long)
+    max_len = int(lengths.max().item())
+    audio = torch.full(
+        (len(batch), max_len),
+        pad_value,
+        dtype=batch[0].audio.dtype,
+        device=batch[0].audio.device,
+    )
+
+    for idx, item in enumerate(batch):
+        audio[idx, : item.audio.numel()] = item.audio
+
+    utterance_ids = [item.utterance_id for item in batch]
+    texts = [item.text for item in batch]
+    speakers = [item.speaker for item in batch]
+
+    metadata: Optional[list[Any]] = None
+    if keep_metadata:
+        metadata = [getattr(item, "metadata", None) for item in batch]
+
+    return CollateBatch(
+        audio=audio,
+        lengths=lengths,
+        sample_rate=sample_rate,
+        utterance_ids=utterance_ids,
+        texts=texts,
+        speakers=speakers,
+        metadata=metadata,
+    )
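Because `DatasetItem` and `collate_dataset_items` are both shown in full above, here is a small self-contained example of the padding behaviour:

```python
import torch

from torchrir.datasets import DatasetItem, collate_dataset_items

items = [
    DatasetItem(audio=torch.randn(8000), sample_rate=16000, utterance_id="utt-0"),
    DatasetItem(audio=torch.randn(12000), sample_rate=16000, utterance_id="utt-1", text="hello"),
]

batch = collate_dataset_items(items, pad_value=0.0)
print(batch.audio.shape)   # torch.Size([2, 12000]) -- padded to the longest item
print(batch.lengths)       # tensor([ 8000, 12000])
print(batch.texts)         # [None, 'hello']
```

The function can also be handed to `torch.utils.data.DataLoader` as `collate_fn` directly; for a non-default `pad_value`, wrap it with `functools.partial`.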
torchrir/datasets/librispeech.py
ADDED

@@ -0,0 +1,175 @@
+from __future__ import annotations
+
+"""LibriSpeech dataset helpers."""
+
+import logging
+import tarfile
+import urllib.request
+from dataclasses import dataclass
+from pathlib import Path
+from typing import List, Tuple
+
+import torch
+
+from .base import BaseDataset
+from .utils import load_wav_mono
+
+BASE_URL = "https://www.openslr.org/resources/12"
+VALID_SUBSETS = {
+    "dev-clean",
+    "dev-other",
+    "test-clean",
+    "test-other",
+    "train-clean-100",
+    "train-clean-360",
+    "train-other-500",
+}
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class LibriSpeechSentence:
+    """Sentence metadata from LibriSpeech."""
+
+    utterance_id: str
+    text: str
+    speaker_id: str
+    chapter_id: str
+
+
+class LibriSpeechDataset(BaseDataset):
+    """LibriSpeech dataset loader.
+
+    Example:
+        >>> dataset = LibriSpeechDataset(Path("datasets/librispeech"), subset="train-clean-100", download=True)
+        >>> audio, fs = dataset.load_wav("103-1240-0000")
+    """
+
+    def __init__(
+        self, root: Path, subset: str = "train-clean-100", download: bool = False
+    ) -> None:
+        """Initialize a LibriSpeech dataset handle.
+
+        Args:
+            root: Root directory where the dataset is stored.
+            subset: LibriSpeech subset name (e.g., "train-clean-100").
+            download: Download and extract if missing.
+        """
+        if subset not in VALID_SUBSETS:
+            raise ValueError(f"unsupported subset: {subset}")
+        self.root = Path(root)
+        self.subset = subset
+        self._archive_name = f"{subset}.tar.gz"
+        self._base_dir = self.root / "LibriSpeech"
+        self._subset_dir = self._base_dir / subset
+
+        if download:
+            self._download_and_extract()
+
+        if not self._subset_dir.exists():
+            raise FileNotFoundError(
+                "dataset not found; run with download=True or place the archive under "
+                f"{self.root}"
+            )
+
+    def list_speakers(self) -> List[str]:
+        """Return available speaker IDs."""
+        if not self._subset_dir.exists():
+            return []
+        return sorted([p.name for p in self._subset_dir.iterdir() if p.is_dir()])
+
+    def available_sentences(self) -> List[LibriSpeechSentence]:
+        """Return sentences that have a corresponding audio file."""
+        sentences: List[LibriSpeechSentence] = []
+        for trans_path in self._subset_dir.rglob("*.trans.txt"):
+            chapter_dir = trans_path.parent
+            speaker_id = chapter_dir.parent.name
+            chapter_id = chapter_dir.name
+            with trans_path.open("r", encoding="utf-8") as f:
+                for line in f:
+                    line = line.strip()
+                    if not line:
+                        continue
+                    utt_id, text = _parse_text_line(line)
+                    wav_path = chapter_dir / f"{utt_id}.flac"
+                    if wav_path.exists():
+                        sentences.append(
+                            LibriSpeechSentence(
+                                utterance_id=utt_id,
+                                text=text,
+                                speaker_id=speaker_id,
+                                chapter_id=chapter_id,
+                            )
+                        )
+        return sentences
+
+    def load_wav(self, utterance_id: str) -> Tuple[torch.Tensor, int]:
+        """Load a mono wav for the given utterance ID."""
+        speaker_id, chapter_id, _ = utterance_id.split("-", 2)
+        path = self._subset_dir / speaker_id / chapter_id / f"{utterance_id}.flac"
+        return load_wav_mono(path)
+
+    def _download_and_extract(self) -> None:
+        """Download and extract the subset archive if needed."""
+        self.root.mkdir(parents=True, exist_ok=True)
+        archive_path = self.root / self._archive_name
+        url = f"{BASE_URL}/{self._archive_name}"
+
+        if not archive_path.exists():
+            logger.info("Downloading %s", url)
+            _download(url, archive_path)
+        if not self._subset_dir.exists():
+            logger.info("Extracting %s", archive_path)
+            try:
+                with tarfile.open(archive_path, "r:gz") as tar:
+                    tar.extractall(self.root)
+            except (tarfile.ReadError, EOFError, OSError) as exc:
+                logger.warning("Extraction failed (%s); re-downloading.", exc)
+                if archive_path.exists():
+                    archive_path.unlink()
+                _download(url, archive_path)
+                with tarfile.open(archive_path, "r:gz") as tar:
+                    tar.extractall(self.root)
+
+
+def _download(url: str, dest: Path, retries: int = 1) -> None:
+    """Download a file with retry and resume-safe temp file."""
+    for attempt in range(retries + 1):
+        try:
+            _stream_download(url, dest)
+            return
+        except Exception as exc:
+            if dest.exists():
+                dest.unlink()
+            if attempt >= retries:
+                raise
+            logger.warning("Download failed (%s); retrying...", exc)
+
+
+def _stream_download(url: str, dest: Path) -> None:
+    """Stream a URL to disk with a progress indicator."""
+    tmp_path = dest.with_suffix(dest.suffix + ".part")
+    if tmp_path.exists():
+        tmp_path.unlink()
+
+    with urllib.request.urlopen(url) as response:
+        total = response.length or 0
+        downloaded = 0
+        chunk_size = 1024 * 1024
+        with tmp_path.open("wb") as f:
+            while True:
+                chunk = response.read(chunk_size)
+                if not chunk:
+                    break
+                f.write(chunk)
+                downloaded += len(chunk)
+    if total > 0 and downloaded != total:
+        raise IOError(f"incomplete download: {downloaded} of {total} bytes")
+    tmp_path.replace(dest)
+
+
+def _parse_text_line(line: str) -> Tuple[str, str]:
+    """Parse a LibriSpeech transcript line into (utterance_id, text)."""
+    left, _, right = line.partition(" ")
+    return left, right.strip()
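Putting the new pieces together: the constructor arguments below follow the class docstring above, and `collate_dataset_items` plugs straight into a `DataLoader` (decoding the `.flac` files goes through `load_wav_mono`, so `soundfile` must be installed). The subset choice is just an example; `dev-clean` is one of the smaller subsets.

```python
from pathlib import Path

from torch.utils.data import DataLoader

from torchrir.datasets import LibriSpeechDataset, collate_dataset_items

dataset = LibriSpeechDataset(Path("datasets/librispeech"), subset="dev-clean", download=True)

loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_dataset_items)
for batch in loader:
    print(batch.audio.shape, batch.sample_rate, batch.utterance_ids)
    break
```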
torchrir/datasets/utils.py
CHANGED

@@ -3,6 +3,7 @@ from __future__ import annotations
 """Dataset-agnostic utilities."""

 import random
+from pathlib import Path
 from typing import Callable, List, Optional, Sequence, Tuple

 import torch

@@ -94,3 +95,20 @@ def load_dataset_sources(
     if fs is None:
         raise RuntimeError("no audio loaded from dataset sources")
     return stacked, int(fs), info
+
+
+def load_wav_mono(path: Path) -> Tuple[torch.Tensor, int]:
+    """Load a wav/flac file and return mono audio and sample rate.
+
+    Example:
+        >>> audio, fs = load_wav_mono(Path("datasets/cmu_arctic/ARCTIC/.../wav/arctic_a0001.wav"))
+    """
+    import soundfile as sf
+
+    audio, sample_rate = sf.read(str(path), dtype="float32", always_2d=True)
+    audio_t = torch.from_numpy(audio)
+    if audio_t.shape[1] > 1:
+        audio_t = audio_t.mean(dim=1)
+    else:
+        audio_t = audio_t.squeeze(1)
+    return audio_t, sample_rate
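A quick round trip showing the mono downmix (multi-channel audio is averaged across channels), assuming `soundfile` is installed; the file name is arbitrary:

```python
from pathlib import Path

import numpy as np
import soundfile as sf

from torchrir.datasets import load_wav_mono

fs = 16000
tone = np.sin(2 * np.pi * 440 * np.arange(fs) / fs).astype("float32")
stereo = np.stack([tone, np.zeros_like(tone)], axis=1)   # (samples, 2)

path = Path("stereo_test.wav")
sf.write(str(path), stereo, fs)

audio, sample_rate = load_wav_mono(path)
print(audio.shape, sample_rate)   # torch.Size([16000]) 16000  (channels averaged)
```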
{torchrir-0.1.4.dist-info → torchrir-0.2.0.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: torchrir
-Version: 0.1.4
+Version: 0.2.0
 Summary: PyTorch-based room impulse response (RIR) simulation toolkit for static and dynamic scenes.
 Project-URL: Repository, https://github.com/taishi-n/torchrir
 Requires-Python: >=3.10

@@ -59,7 +59,7 @@ For detailed documentation, see the docs under `docs/` and Read the Docs.
 ## Future Work
 - Ray tracing backend: implement `RayTracingSimulator` with frequency-dependent absorption/scattering.
 - CUDA-native acceleration: introduce dedicated CUDA kernels for large-scale RIR generation.
-- Dataset expansion: add additional dataset integrations beyond CMU ARCTIC (see `TemplateDataset`).
+- Dataset expansion: add additional dataset integrations beyond CMU ARCTIC (see `TemplateDataset`), including torchaudio datasets (e.g., LibriSpeech, VCTK, LibriTTS, SpeechCommands, CommonVoice, GTZAN, MUSDB-HQ).
 - Add regression tests comparing generated RIRs against gpuRIR outputs.

 ## Related Libraries

{torchrir-0.1.4.dist-info → torchrir-0.2.0.dist-info}/RECORD
CHANGED

@@ -1,7 +1,7 @@
-torchrir/__init__.py,sha256=
+torchrir/__init__.py,sha256=MTlouAErvB7IyM4pcmhEN1U0KQsZBglWxUoHxnZDa5U,2615
 torchrir/animation.py,sha256=x3Y-BLz3J6DQNmoDIjbMEgGfng2yavJFLyQEmRCSpQU,6391
 torchrir/config.py,sha256=PsZdDIS3p4jepeNSHyd69aSD9QlOEdpG9v1SAXlZ_Fg,2295
-torchrir/core.py,sha256=
+torchrir/core.py,sha256=Ug5thts1rXvCpdq9twVHz72oWygmk5J6gliloozHKL4,31704
 torchrir/directivity.py,sha256=v_t37YgeXF_IYzbnrk0TCs1npb_0yKR7zHiG8XV3V4w,1259
 torchrir/dynamic.py,sha256=01JHMxhORdcz93J-YaMIeSLo7k2tHrZke8llPHHXwZg,2153
 torchrir/logging_utils.py,sha256=s4jDSSDoHT0HKeplDUpGMsdeBij4eibLSpaaAPzkB68,2146

@@ -15,14 +15,16 @@ torchrir/scene_utils.py,sha256=2La5dtjxYdINX315VXRRJMJK9oaR2rY0xHmDLjZma8M,2140
 torchrir/signal.py,sha256=M0BpKDBqrfOmCHIJ_dvl-C3uKdFpXLDqtSIU115jsME,8383
 torchrir/simulators.py,sha256=NCl8Ptv2TGdBpNLwAb3nigT77On-BLIANtc2ivgKasw,3131
 torchrir/utils.py,sha256=2oE-JzAtkW5qdRds2Y5R5lbSyNZl_9piFXd6xOLzjxM,10680
-torchrir/datasets/__init__.py,sha256=
-torchrir/datasets/base.py,sha256=
-torchrir/datasets/cmu_arctic.py,sha256=
+torchrir/datasets/__init__.py,sha256=NS4zQas9YdsuDv8KQtTKmIJmS6mxMRxQk2xGzglbgUw,853
+torchrir/datasets/base.py,sha256=LfdXO-NGCBtzaAqAeVxo8XuV5ieU6Vl91woqAHymsT8,1970
+torchrir/datasets/cmu_arctic.py,sha256=7IFv33RBBu044kTMO6nKUmziml2gjILUgnpL262rAU8,6593
+torchrir/datasets/collate.py,sha256=gZfaHog0gtb8Avg6qsDZ1m4yoKkYkcuwmty1RtLYhhI,2542
+torchrir/datasets/librispeech.py,sha256=XKlAm0Z0coipKuqR9Z8X8l9puXVYz7zb6yE3PCuMUrI,6019
 torchrir/datasets/template.py,sha256=pHAKj5E7Gehfk9pqdTsFQjiDV1OK3hSZJIbYutd-E4c,2090
-torchrir/datasets/utils.py,sha256=
-torchrir-0.
-torchrir-0.
-torchrir-0.
-torchrir-0.
-torchrir-0.
-torchrir-0.
+torchrir/datasets/utils.py,sha256=OCYd7Dbr2hsqBbiHE1LHPMYdqwe2YfDw0tpRMfND0Og,3790
+torchrir-0.2.0.dist-info/licenses/LICENSE,sha256=5vS_7WTsMEw_QQHEPQ_WCwovJXEgmxoEwcwOI-9VbXI,10766
+torchrir-0.2.0.dist-info/licenses/NOTICE,sha256=SRs_q-ZqoVF9_YuuedZOvVBk01jV7YQAeF8rRvlRg0s,118
+torchrir-0.2.0.dist-info/METADATA,sha256=0f5EhK2SWmP_RxZ2O8y7AIDP_dFaKaf9X_n0mt7045k,3077
+torchrir-0.2.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+torchrir-0.2.0.dist-info/top_level.txt,sha256=aIFwntowJjvm7rZk480HymC3ipDo1g-9hEbNY1wF-Oo,9
+torchrir-0.2.0.dist-info/RECORD,,

{torchrir-0.1.4.dist-info → torchrir-0.2.0.dist-info}/WHEEL
File without changes

{torchrir-0.1.4.dist-info → torchrir-0.2.0.dist-info}/licenses/LICENSE
File without changes

{torchrir-0.1.4.dist-info → torchrir-0.2.0.dist-info}/licenses/NOTICE
File without changes

{torchrir-0.1.4.dist-info → torchrir-0.2.0.dist-info}/top_level.txt
File without changes