torchrir-0.1.0-py3-none-any.whl

@@ -0,0 +1,27 @@
+ """Dataset helpers for torchrir."""
+
+ from .base import BaseDataset, SentenceLike
+ from .utils import choose_speakers, load_dataset_sources
+ from .template import TemplateDataset, TemplateSentence
+
+ from .cmu_arctic import (
+     CmuArcticDataset,
+     CmuArcticSentence,
+     list_cmu_arctic_speakers,
+     load_wav_mono,
+     save_wav,
+ )
+
+ __all__ = [
+     "BaseDataset",
+     "CmuArcticDataset",
+     "CmuArcticSentence",
+     "choose_speakers",
+     "list_cmu_arctic_speakers",
+     "SentenceLike",
+     "load_dataset_sources",
+     "load_wav_mono",
+     "save_wav",
+     "TemplateDataset",
+     "TemplateSentence",
+ ]
@@ -0,0 +1,27 @@
+ """Dataset protocol definitions."""
+
+ from __future__ import annotations
+
+ from typing import Protocol, Sequence, Tuple
+
+ import torch
+
+
+ class SentenceLike(Protocol):
+     """Minimal sentence interface for dataset entries."""
+
+     utterance_id: str
+     text: str
+
+
+ class BaseDataset(Protocol):
+     """Protocol for datasets used in torchrir examples and tools."""
+
+     def list_speakers(self) -> list[str]:
+         """Return available speaker IDs."""
+
+     def available_sentences(self) -> Sequence[SentenceLike]:
+         """Return sentence entries that have audio available."""
+
+     def load_wav(self, utterance_id: str) -> Tuple[torch.Tensor, int]:
+         """Load audio for an utterance and return (audio, sample_rate)."""
@@ -0,0 +1,204 @@
+ """CMU ARCTIC dataset helpers."""
+
+ from __future__ import annotations
+
+ import logging
+ import tarfile
+ import urllib.request
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import List, Tuple
+
+ import torch
+
+ BASE_URL = "http://www.festvox.org/cmu_arctic/packed"
+ VALID_SPEAKERS = {
+     "aew",
+     "ahw",
+     "aup",
+     "awb",
+     "axb",
+     "bdl",
+     "clb",
+     "eey",
+     "fem",
+     "gka",
+     "jmk",
+     "ksp",
+     "ljm",
+     "lnh",
+     "rms",
+     "rxr",
+     "slp",
+     "slt",
+ }
+
+ logger = logging.getLogger(__name__)
+
+
+ def list_cmu_arctic_speakers() -> List[str]:
+     """Return supported CMU ARCTIC speaker IDs."""
+     return sorted(VALID_SPEAKERS)
+
+
+ @dataclass
+ class CmuArcticSentence:
+     """Sentence metadata from CMU ARCTIC."""
+     utterance_id: str
+     text: str
+
+
+ class CmuArcticDataset:
+     def __init__(self, root: Path, speaker: str = "bdl", download: bool = False) -> None:
+         """Initialize a CMU ARCTIC dataset handle.
+
+         Args:
+             root: Root directory where the dataset is stored.
+             speaker: Speaker ID (e.g., "bdl").
+             download: Download and extract if missing.
+         """
+         if speaker not in VALID_SPEAKERS:
+             raise ValueError(f"unsupported speaker: {speaker}")
+         self.root = Path(root)
+         self.speaker = speaker
+         self._base_dir = self.root / "ARCTIC"
+         self._archive_name = f"cmu_us_{speaker}_arctic.tar.bz2"
+         self._dataset_dir = self._base_dir / f"cmu_us_{speaker}_arctic"
+
+         if download:
+             self._download_and_extract()
+
+         if not self._dataset_dir.exists():
+             raise FileNotFoundError(
+                 "dataset not found; run with download=True or place the archive under "
+                 f"{self._base_dir}"
+             )
+
+     @property
+     def wav_dir(self) -> Path:
+         """Return the directory containing wav files."""
+         return self._dataset_dir / "wav"
+
+     @property
+     def text_path(self) -> Path:
+         """Return the path to txt.done.data."""
+         return self._dataset_dir / "etc" / "txt.done.data"
+
+     def _download_and_extract(self) -> None:
+         """Download and extract the speaker archive if needed."""
+         self._base_dir.mkdir(parents=True, exist_ok=True)
+         archive_path = self._base_dir / self._archive_name
+         url = f"{BASE_URL}/{self._archive_name}"
+
+         if not archive_path.exists():
+             logger.info("Downloading %s", url)
+             _download(url, archive_path)
+         if not self._dataset_dir.exists():
+             logger.info("Extracting %s", archive_path)
+             try:
+                 with tarfile.open(archive_path, "r:bz2") as tar:
+                     tar.extractall(self._base_dir)
+             except (tarfile.ReadError, EOFError, OSError) as exc:
+                 logger.warning("Extraction failed (%s); re-downloading.", exc)
+                 if archive_path.exists():
+                     archive_path.unlink()
+                 _download(url, archive_path)
+                 with tarfile.open(archive_path, "r:bz2") as tar:
+                     tar.extractall(self._base_dir)
+
+     def sentences(self) -> List[CmuArcticSentence]:
+         """Parse all sentence metadata."""
+         sentences: List[CmuArcticSentence] = []
+         with self.text_path.open("r", encoding="utf-8") as f:
+             for line in f:
+                 line = line.strip()
+                 if not line:
+                     continue
+                 utt, text = _parse_text_line(line)
+                 sentences.append(CmuArcticSentence(utterance_id=utt, text=text))
+         return sentences
+
+     def available_sentences(self) -> List[CmuArcticSentence]:
+         """Return sentences that have a corresponding wav file."""
+         wav_ids = {p.stem for p in self.wav_dir.glob("*.wav")}
+         return [s for s in self.sentences() if s.utterance_id in wav_ids]
+
+     def list_speakers(self) -> List[str]:
+         """Return available speaker IDs."""
+         return list_cmu_arctic_speakers()
+
+     def wav_path(self, utterance_id: str) -> Path:
+         """Return the wav path for an utterance ID."""
+         return self.wav_dir / f"{utterance_id}.wav"
+
+     def load_wav(self, utterance_id: str) -> Tuple[torch.Tensor, int]:
+         """Load a mono wav for the given utterance ID."""
+         path = self.wav_path(utterance_id)
+         return load_wav_mono(path)
+
+
+ def _download(url: str, dest: Path, retries: int = 1) -> None:
+     """Download a file with retries, writing through a temp file."""
+     for attempt in range(retries + 1):
+         try:
+             _stream_download(url, dest)
+             return
+         except Exception as exc:
+             if dest.exists():
+                 dest.unlink()
+             if attempt >= retries:
+                 raise
+             logger.warning("Download failed (%s); retrying...", exc)
+
+
+ def _stream_download(url: str, dest: Path) -> None:
+     """Stream a URL to disk via a temporary .part file and verify its size."""
+     tmp_path = dest.with_suffix(dest.suffix + ".part")
+     if tmp_path.exists():
+         tmp_path.unlink()
+
+     with urllib.request.urlopen(url) as response:
+         total = response.length or 0
+         downloaded = 0
+         chunk_size = 1024 * 1024
+         with tmp_path.open("wb") as f:
+             while True:
+                 chunk = response.read(chunk_size)
+                 if not chunk:
+                     break
+                 f.write(chunk)
+                 downloaded += len(chunk)
+     if total > 0 and downloaded != total:
+         raise IOError(f"incomplete download: {downloaded} of {total} bytes")
+     tmp_path.replace(dest)
+
+
+ def _parse_text_line(line: str) -> Tuple[str, str]:
+     """Parse a txt.done.data line into (utterance_id, text)."""
+     left, _, right = line.partition('"')
+     utterance = left.replace("(", "").strip().split()[0]
+     text = right.rsplit('"', 1)[0]
+     return utterance, text
+
+
+ def load_wav_mono(path: Path) -> Tuple[torch.Tensor, int]:
+     """Load a wav file and return mono audio and sample rate."""
+     import soundfile as sf
+
+     audio, sample_rate = sf.read(str(path), dtype="float32", always_2d=True)
+     audio_t = torch.from_numpy(audio)
+     if audio_t.shape[1] > 1:
+         audio_t = audio_t.mean(dim=1)
+     else:
+         audio_t = audio_t.squeeze(1)
+     return audio_t, sample_rate
+
+
+ def save_wav(path: Path, audio: torch.Tensor, sample_rate: int) -> None:
+     """Save a mono or multi-channel wav to disk."""
+     import soundfile as sf
+
+     audio = audio.detach().cpu().clamp(-1.0, 1.0).to(torch.float32)
+     if audio.ndim == 2 and audio.shape[0] <= 8:
+         audio = audio.transpose(0, 1)
+     sf.write(str(path), audio.numpy(), sample_rate)
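
Note (editor's sketch, not part of the packaged diff): a minimal end-to-end use of
CmuArcticDataset, assuming the package layout implied by the imports above
(torchrir.datasets); the cache directory is arbitrary:

    from pathlib import Path

    from torchrir.datasets import CmuArcticDataset

    dataset = CmuArcticDataset(root=Path("./data"), speaker="bdl", download=True)
    sentences = dataset.available_sentences()
    audio, fs = dataset.load_wav(sentences[0].utterance_id)
    print(sentences[0].text, tuple(audio.shape), fs)
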
@@ -0,0 +1,65 @@
+ """Dataset template for future extensions.
+
+ Work in progress:
+     This module is a placeholder for future dataset integrations. The goal is
+     to provide a consistent interface for downloading, caching, enumerating
+     speakers/utterances, and loading audio in a reproducible way.
+ """
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import List, Sequence, Tuple
+
+ import torch
+
+ from .base import BaseDataset, SentenceLike
+
+
+ @dataclass
+ class TemplateSentence:
+     """Minimal sentence metadata for a template dataset."""
+
+     utterance_id: str
+     text: str
+
+
+ class TemplateDataset(BaseDataset):
+     """Work-in-progress template dataset implementation.
+
+     Goal:
+         Implement concrete dataset handlers by filling in download logic,
+         metadata parsing, and audio loading while keeping the BaseDataset
+         protocol intact.
+     """
+
+     def __init__(self, root: Path, speaker: str = "default", download: bool = False) -> None:
+         self.root = Path(root)
+         self.speaker = speaker
+         if download:
+             raise NotImplementedError(
+                 "download is not implemented yet. Intended to fetch and cache "
+                 "dataset archives under root."
+             )
+
+     def list_speakers(self) -> List[str]:
+         """Return available speaker IDs."""
+         return ["default"]
+
+     def available_sentences(self) -> Sequence[SentenceLike]:
+         """Return sentence entries that have audio available.
+
+         Work in progress:
+             Intended to parse dataset metadata and filter to utterances that
+             have corresponding audio files on disk.
+         """
+         raise NotImplementedError("available_sentences is not implemented yet")
+
+     def load_wav(self, utterance_id: str) -> Tuple[torch.Tensor, int]:
+         """Load audio for an utterance and return (audio, sample_rate).
+
+         Work in progress:
+             Intended to load audio from local cache and return mono float32.
+         """
+         raise NotImplementedError("load_wav is not implemented yet")
@@ -0,0 +1,74 @@
+ """Dataset-agnostic utilities."""
+
+ from __future__ import annotations
+
+ import random
+ from typing import Callable, List, Optional, Sequence, Tuple
+
+ import torch
+
+ from .base import BaseDataset, SentenceLike
+
+
+ def choose_speakers(dataset: BaseDataset, num_sources: int, rng: random.Random) -> List[str]:
+     """Select unique speakers for the requested number of sources."""
+     speakers = dataset.list_speakers()
+     if not speakers:
+         raise RuntimeError("no speakers available")
+     if num_sources > len(speakers):
+         raise ValueError(f"num_sources must be <= {len(speakers)} for unique speakers")
+     return rng.sample(speakers, num_sources)
+
+
+ def load_dataset_sources(
+     *,
+     dataset_factory: Callable[[Optional[str]], BaseDataset],
+     num_sources: int,
+     duration_s: float,
+     rng: random.Random,
+ ) -> Tuple[torch.Tensor, int, List[Tuple[str, List[str]]]]:
+     """Load and concatenate utterances for each speaker into fixed-length signals."""
+     dataset0 = dataset_factory(None)
+     speakers = choose_speakers(dataset0, num_sources, rng)
+     signals: List[torch.Tensor] = []
+     info: List[Tuple[str, List[str]]] = []
+     fs: int | None = None
+     target_samples: int | None = None
+
+     for speaker in speakers:
+         dataset = dataset_factory(speaker)
+         sentences: Sequence[SentenceLike] = dataset.available_sentences()
+         if not sentences:
+             raise RuntimeError(f"no sentences found for speaker {speaker}")
+
+         utterance_ids: List[str] = []
+         segments: List[torch.Tensor] = []
+         total = 0
+         sentences = list(sentences)
+         rng.shuffle(sentences)
+         idx = 0
+
+         while target_samples is None or total < target_samples:
+             if idx >= len(sentences):
+                 rng.shuffle(sentences)
+                 idx = 0
+             sentence = sentences[idx]
+             idx += 1
+             audio, sample_rate = dataset.load_wav(sentence.utterance_id)
+             if fs is None:
+                 fs = sample_rate
+                 target_samples = int(duration_s * fs)
+             elif sample_rate != fs:
+                 raise ValueError(
+                     f"sample rate mismatch: expected {fs}, got {sample_rate} for {speaker}"
+                 )
+             segments.append(audio)
+             utterance_ids.append(sentence.utterance_id)
+             total += audio.numel()
+
+         signal = torch.cat(segments, dim=0)[:target_samples]
+         signals.append(signal)
+         info.append((speaker, utterance_ids))
+
+     stacked = torch.stack(signals, dim=0)
+     return stacked, int(fs), info
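
Note (editor's sketch, not part of the packaged diff): load_dataset_sources takes a
factory so it can build one dataset per chosen speaker; the factory is first called
with None just to enumerate speakers. Assuming the CMU ARCTIC dataset from this diff
and the torchrir.datasets import path:

    import random

    from torchrir.datasets import CmuArcticDataset, load_dataset_sources

    signals, fs, info = load_dataset_sources(
        dataset_factory=lambda spk: CmuArcticDataset(
            root="./data", speaker=spk or "bdl", download=True
        ),
        num_sources=2,
        duration_s=5.0,
        rng=random.Random(0),
    )
    # signals: (2, int(5.0 * fs)) tensor; info: [(speaker, [utterance_ids]), ...]
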
@@ -0,0 +1,33 @@
+ """Directivity pattern utilities."""
+
+ from __future__ import annotations
+
+ import torch
+ from torch import Tensor
+
+
+ def directivity_gain(pattern: str, cos_theta: Tensor) -> Tensor:
+     """Compute directivity gain for a pattern given cos(theta)."""
+     pattern = pattern.lower()
+     if pattern in ("omni", "omnidirectional"):
+         return torch.ones_like(cos_theta)
+     if pattern in ("homni", "halfomni", "half-omni"):
+         return (cos_theta > 0).to(cos_theta.dtype)
+     if pattern in ("subcardioid", "subcard"):
+         return 0.75 + 0.25 * cos_theta
+     if pattern in ("cardioid", "card"):
+         return 0.5 + 0.5 * cos_theta
+     if pattern in ("hypercardioid", "hypcard"):
+         return 0.25 + 0.75 * cos_theta
+     if pattern in ("bidir", "bidirectional", "figure8", "figure-8"):
+         return cos_theta
+     raise ValueError(f"unsupported directivity pattern: {pattern}")
+
+
+ def split_directivity(directivity: str | tuple[str, str]) -> tuple[str, str]:
+     """Normalize a directivity specification into (source, mic)."""
+     if isinstance(directivity, (list, tuple)):
+         if len(directivity) != 2:
+             raise ValueError("directivity tuple must have length 2")
+         return directivity[0], directivity[1]
+     return directivity, directivity
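
Note (editor's sketch, not part of the packaged diff): apart from half-omni, the
gains above are the standard first-order family a + (1 - a) * cos(theta), with
a = 1 (omni), 0.75 (subcardioid), 0.5 (cardioid), 0.25 (hypercardioid), 0 (figure-8).
The torchrir.directivity import path is assumed:

    import torch

    from torchrir.directivity import directivity_gain, split_directivity

    cos_theta = torch.tensor([1.0, 0.0, -1.0])  # on-axis, 90 degrees, rear
    print(directivity_gain("cardioid", cos_theta))  # tensor([1.0000, 0.5000, 0.0000])
    print(split_directivity("omni"))                # ('omni', 'omni')
    print(split_directivity(("card", "omni")))      # ('card', 'omni')
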
torchrir/dynamic.py ADDED
@@ -0,0 +1,60 @@
+ """Dynamic convolution utilities.
+
+ DynamicConvolver is the public API for time-varying convolution. Lower-level
+ helpers live in signal.py and are not part of the stable surface.
+ """
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Optional
+
+ import torch
+ from torch import Tensor
+
+ from .signal import _ensure_dynamic_rirs, _ensure_signal
+
+
+ @dataclass(frozen=True)
+ class DynamicConvolver:
+     """Convolver for time-varying RIRs."""
+
+     mode: str = "trajectory"
+     hop: Optional[int] = None
+     timestamps: Optional[Tensor] = None
+     fs: Optional[float] = None
+
+     def __call__(self, signal: Tensor, rirs: Tensor) -> Tensor:
+         return self.convolve(signal, rirs)
+
+     def convolve(self, signal: Tensor, rirs: Tensor) -> Tensor:
+         """Convolve signals with time-varying RIRs."""
+         if self.mode not in ("trajectory", "hop"):
+             raise ValueError("mode must be 'trajectory' or 'hop'")
+         if self.mode == "hop":
+             if self.hop is None:
+                 raise ValueError("hop must be provided for hop mode")
+             return _convolve_dynamic_hop(signal, rirs, self.hop)
+         return _convolve_dynamic_trajectory(signal, rirs, timestamps=self.timestamps, fs=self.fs)
+
+
+ def _convolve_dynamic_hop(signal: Tensor, rirs: Tensor, hop: int) -> Tensor:
+     from .signal import _convolve_dynamic_rir_hop
+
+     signal = _ensure_signal(signal)
+     rirs = _ensure_dynamic_rirs(rirs, signal)
+     return _convolve_dynamic_rir_hop(signal, rirs, hop)
+
+
+ def _convolve_dynamic_trajectory(
+     signal: Tensor,
+     rirs: Tensor,
+     *,
+     timestamps: Optional[Tensor],
+     fs: Optional[float],
+ ) -> Tensor:
+     from .signal import _convolve_dynamic_rir_trajectory
+
+     signal = _ensure_signal(signal)
+     rirs = _ensure_dynamic_rirs(rirs, signal)
+     return _convolve_dynamic_rir_trajectory(signal, rirs, timestamps=timestamps, fs=fs)
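
Note (editor's sketch, not part of the packaged diff): hop mode advances to the next
RIR every `hop` samples. The exact RIR tensor layout is enforced by
_ensure_dynamic_rirs in signal.py, which this diff does not include, so the
(num_frames, rir_length) shape below is an assumption:

    import torch

    from torchrir.dynamic import DynamicConvolver

    fs = 16000
    signal = torch.randn(fs)        # 1 s of audio
    rirs = torch.randn(10, 4096)    # assumed layout: one RIR per frame

    convolver = DynamicConvolver(mode="hop", hop=fs // 10)
    wet = convolver(signal, rirs)
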
@@ -0,0 +1,55 @@
+ """Logging helpers for torchrir."""
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass, replace
+ import logging
+ from typing import Optional
+
+
+ @dataclass(frozen=True)
+ class LoggingConfig:
+     """Configuration for torchrir logging."""
+
+     level: str | int = "INFO"
+     format: str = "%(levelname)s:%(name)s:%(message)s"
+     datefmt: Optional[str] = None
+     propagate: bool = False
+
+     def resolve_level(self) -> int:
+         """Resolve level to a logging integer constant."""
+         if isinstance(self.level, int):
+             return self.level
+         if not isinstance(self.level, str):
+             raise TypeError("level must be str or int")
+         level = logging.getLevelName(self.level.upper())
+         if not isinstance(level, int):
+             raise ValueError(f"unknown log level: {self.level}")
+         return level
+
+     def replace(self, **kwargs) -> "LoggingConfig":
+         """Return a new config with updated fields."""
+         return replace(self, **kwargs)
+
+
+ def setup_logging(config: LoggingConfig, *, name: str = "torchrir") -> logging.Logger:
+     """Configure and return the base torchrir logger."""
+     logger = logging.getLogger(name)
+     level = config.resolve_level()
+     logger.setLevel(level)
+     logger.propagate = config.propagate
+     if not logger.handlers:
+         handler = logging.StreamHandler()
+         handler.setLevel(level)
+         handler.setFormatter(logging.Formatter(config.format, datefmt=config.datefmt))
+         logger.addHandler(handler)
+     return logger
+
+
+ def get_logger(name: Optional[str] = None) -> logging.Logger:
+     """Return a torchrir logger, namespaced under the torchrir root."""
+     if not name:
+         return logging.getLogger("torchrir")
+     if name.startswith("torchrir"):
+         return logging.getLogger(name)
+     return logging.getLogger(f"torchrir.{name}")