torchrir 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchrir/__init__.py +85 -0
- torchrir/config.py +59 -0
- torchrir/core.py +741 -0
- torchrir/datasets/__init__.py +27 -0
- torchrir/datasets/base.py +27 -0
- torchrir/datasets/cmu_arctic.py +204 -0
- torchrir/datasets/template.py +65 -0
- torchrir/datasets/utils.py +74 -0
- torchrir/directivity.py +33 -0
- torchrir/dynamic.py +60 -0
- torchrir/logging_utils.py +55 -0
- torchrir/plotting.py +210 -0
- torchrir/plotting_utils.py +173 -0
- torchrir/results.py +22 -0
- torchrir/room.py +150 -0
- torchrir/scene.py +67 -0
- torchrir/scene_utils.py +51 -0
- torchrir/signal.py +233 -0
- torchrir/simulators.py +86 -0
- torchrir/utils.py +281 -0
- torchrir-0.1.0.dist-info/METADATA +213 -0
- torchrir-0.1.0.dist-info/RECORD +26 -0
- torchrir-0.1.0.dist-info/WHEEL +5 -0
- torchrir-0.1.0.dist-info/licenses/LICENSE +190 -0
- torchrir-0.1.0.dist-info/licenses/NOTICE +4 -0
- torchrir-0.1.0.dist-info/top_level.txt +1 -0
torchrir/datasets/__init__.py
ADDED
@@ -0,0 +1,27 @@
+"""Dataset helpers for torchrir."""
+
+from .base import BaseDataset, SentenceLike
+from .utils import choose_speakers, load_dataset_sources
+from .template import TemplateDataset, TemplateSentence
+
+from .cmu_arctic import (
+    CmuArcticDataset,
+    CmuArcticSentence,
+    list_cmu_arctic_speakers,
+    load_wav_mono,
+    save_wav,
+)
+
+__all__ = [
+    "BaseDataset",
+    "CmuArcticDataset",
+    "CmuArcticSentence",
+    "choose_speakers",
+    "list_cmu_arctic_speakers",
+    "SentenceLike",
+    "load_dataset_sources",
+    "load_wav_mono",
+    "save_wav",
+    "TemplateDataset",
+    "TemplateSentence",
+]
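Everything in __all__ above is importable directly from torchrir.datasets. A minimal orientation sketch (the data directory is hypothetical):

    from torchrir.datasets import CmuArcticDataset, list_cmu_arctic_speakers

    print(list_cmu_arctic_speakers())   # ['aew', 'ahw', ..., 'slt'], 18 speakers
    ds = CmuArcticDataset("data", speaker="bdl", download=True)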
torchrir/datasets/base.py
ADDED
@@ -0,0 +1,27 @@
+from __future__ import annotations
+
+"""Dataset protocol definitions."""
+
+from typing import Protocol, Sequence, Tuple
+
+import torch
+
+
+class SentenceLike(Protocol):
+    """Minimal sentence interface for dataset entries."""
+
+    utterance_id: str
+    text: str
+
+
+class BaseDataset(Protocol):
+    """Protocol for datasets used in torchrir examples and tools."""
+
+    def list_speakers(self) -> list[str]:
+        """Return available speaker IDs."""
+
+    def available_sentences(self) -> Sequence[SentenceLike]:
+        """Return sentence entries that have audio available."""
+
+    def load_wav(self, utterance_id: str) -> Tuple[torch.Tensor, int]:
+        """Load audio for an utterance and return (audio, sample_rate)."""
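BaseDataset and SentenceLike are typing.Protocol definitions, so conformance is structural rather than by inheritance. A sketch of a conforming dummy dataset (all names below are illustrative):

    import torch
    from torchrir.datasets import BaseDataset

    class SilenceDataset:
        def list_speakers(self) -> list[str]:
            return ["demo"]

        def available_sentences(self):
            return []

        def load_wav(self, utterance_id: str) -> tuple[torch.Tensor, int]:
            return torch.zeros(16000), 16000  # one second of silence at 16 kHz

    ds: BaseDataset = SilenceDataset()  # accepted structurally by a type checker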
torchrir/datasets/cmu_arctic.py
ADDED
@@ -0,0 +1,204 @@
+from __future__ import annotations
+
+"""CMU ARCTIC dataset helpers."""
+
+import tarfile
+import urllib.request
+from dataclasses import dataclass
+from pathlib import Path
+from typing import List, Tuple
+
+import torch
+import logging
+
+BASE_URL = "http://www.festvox.org/cmu_arctic/packed"
+VALID_SPEAKERS = {
+    "aew",
+    "ahw",
+    "aup",
+    "awb",
+    "axb",
+    "bdl",
+    "clb",
+    "eey",
+    "fem",
+    "gka",
+    "jmk",
+    "ksp",
+    "ljm",
+    "lnh",
+    "rms",
+    "rxr",
+    "slp",
+    "slt",
+}
+
+logger = logging.getLogger(__name__)
+
+
+def list_cmu_arctic_speakers() -> List[str]:
+    """Return supported CMU ARCTIC speaker IDs."""
+    return sorted(VALID_SPEAKERS)
+
+
+@dataclass
+class CmuArcticSentence:
+    """Sentence metadata from CMU ARCTIC."""
+    utterance_id: str
+    text: str
+
+
+class CmuArcticDataset:
+    def __init__(self, root: Path, speaker: str = "bdl", download: bool = False) -> None:
+        """Initialize a CMU ARCTIC dataset handle.
+
+        Args:
+            root: Root directory where the dataset is stored.
+            speaker: Speaker ID (e.g., "bdl").
+            download: Download and extract if missing.
+        """
+        if speaker not in VALID_SPEAKERS:
+            raise ValueError(f"unsupported speaker: {speaker}")
+        self.root = Path(root)
+        self.speaker = speaker
+        self._base_dir = self.root / "ARCTIC"
+        self._archive_name = f"cmu_us_{speaker}_arctic.tar.bz2"
+        self._dataset_dir = self._base_dir / f"cmu_us_{speaker}_arctic"
+
+        if download:
+            self._download_and_extract()
+
+        if not self._dataset_dir.exists():
+            raise FileNotFoundError(
+                "dataset not found; run with download=True or place the archive under "
+                f"{self._base_dir}"
+            )
+
+    @property
+    def wav_dir(self) -> Path:
+        """Return the directory containing wav files."""
+        return self._dataset_dir / "wav"
+
+    @property
+    def text_path(self) -> Path:
+        """Return the path to txt.done.data."""
+        return self._dataset_dir / "etc" / "txt.done.data"
+
+    def _download_and_extract(self) -> None:
+        """Download and extract the speaker archive if needed."""
+        self._base_dir.mkdir(parents=True, exist_ok=True)
+        archive_path = self._base_dir / self._archive_name
+        url = f"{BASE_URL}/{self._archive_name}"
+
+        if not archive_path.exists():
+            logger.info("Downloading %s", url)
+            _download(url, archive_path)
+        if not self._dataset_dir.exists():
+            logger.info("Extracting %s", archive_path)
+            try:
+                with tarfile.open(archive_path, "r:bz2") as tar:
+                    tar.extractall(self._base_dir)
+            except (tarfile.ReadError, EOFError, OSError) as exc:
+                logger.warning("Extraction failed (%s); re-downloading.", exc)
+                if archive_path.exists():
+                    archive_path.unlink()
+                _download(url, archive_path)
+                with tarfile.open(archive_path, "r:bz2") as tar:
+                    tar.extractall(self._base_dir)
+
+    def sentences(self) -> List[CmuArcticSentence]:
+        """Parse all sentence metadata."""
+        sentences: List[CmuArcticSentence] = []
+        with self.text_path.open("r", encoding="utf-8") as f:
+            for line in f:
+                line = line.strip()
+                if not line:
+                    continue
+                utt, text = _parse_text_line(line)
+                sentences.append(CmuArcticSentence(utterance_id=utt, text=text))
+        return sentences
+
+    def available_sentences(self) -> List[CmuArcticSentence]:
+        """Return sentences that have a corresponding wav file."""
+        wav_ids = {p.stem for p in self.wav_dir.glob("*.wav")}
+        return [s for s in self.sentences() if s.utterance_id in wav_ids]
+
+    def list_speakers(self) -> List[str]:
+        """Return available speaker IDs."""
+        return list_cmu_arctic_speakers()
+
+    def wav_path(self, utterance_id: str) -> Path:
+        """Return the wav path for an utterance ID."""
+        return self.wav_dir / f"{utterance_id}.wav"
+
+    def load_wav(self, utterance_id: str) -> Tuple[torch.Tensor, int]:
+        """Load a mono wav for the given utterance ID."""
+        path = self.wav_path(utterance_id)
+        return load_wav_mono(path)
+
+
+def _download(url: str, dest: Path, retries: int = 1) -> None:
+    """Download a file with retries, writing through a temporary file."""
+    for attempt in range(retries + 1):
+        try:
+            _stream_download(url, dest)
+            return
+        except Exception as exc:
+            if dest.exists():
+                dest.unlink()
+            if attempt >= retries:
+                raise
+            logger.warning("Download failed (%s); retrying...", exc)
+
+
+def _stream_download(url: str, dest: Path) -> None:
+    """Stream a URL to disk via a ".part" temp file, verifying the length."""
+    tmp_path = dest.with_suffix(dest.suffix + ".part")
+    if tmp_path.exists():
+        tmp_path.unlink()
+
+    with urllib.request.urlopen(url) as response:
+        total = response.length or 0
+        downloaded = 0
+        chunk_size = 1024 * 1024
+        with tmp_path.open("wb") as f:
+            while True:
+                chunk = response.read(chunk_size)
+                if not chunk:
+                    break
+                f.write(chunk)
+                downloaded += len(chunk)
+    if total > 0 and downloaded != total:
+        raise IOError(f"incomplete download: {downloaded} of {total} bytes")
+    tmp_path.replace(dest)
+
+
+def _parse_text_line(line: str) -> Tuple[str, str]:
+    """Parse a txt.done.data line into (utterance_id, text)."""
+    left, _, right = line.partition('"')
+    utterance = left.replace("(", "").strip().split()[0]
+    text = right.rsplit('"', 1)[0]
+    return utterance, text
+
+
+def load_wav_mono(path: Path) -> Tuple[torch.Tensor, int]:
+    """Load a wav file and return mono audio and sample rate."""
+    import soundfile as sf
+
+    audio, sample_rate = sf.read(str(path), dtype="float32", always_2d=True)
+    audio_t = torch.from_numpy(audio)
+    if audio_t.shape[1] > 1:
+        audio_t = audio_t.mean(dim=1)
+    else:
+        audio_t = audio_t.squeeze(1)
+    return audio_t, sample_rate
+
+
+def save_wav(path: Path, audio: torch.Tensor, sample_rate: int) -> None:
+    """Save a mono or multi-channel wav to disk."""
+    import soundfile as sf
+
+    audio = audio.detach().cpu().clamp(-1.0, 1.0).to(torch.float32)
+    if audio.ndim == 2 and audio.shape[0] <= 8:
+        audio = audio.transpose(0, 1)
+    sf.write(str(path), audio.numpy(), sample_rate)
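Each txt.done.data line follows the Festival convention ( arctic_a0001 "Author of the danger trail, Philip Steels, etc." ), which _parse_text_line splits at the first and last double quote. A usage sketch of the class above (the cache directory is hypothetical; ARCTIC audio is nominally 16 kHz mono):

    from pathlib import Path
    from torchrir.datasets import CmuArcticDataset

    ds = CmuArcticDataset(Path("data"), speaker="slt", download=True)
    first = ds.available_sentences()[0]
    audio, fs = ds.load_wav(first.utterance_id)  # 1-D float32 tensor, sample rate
    print(first.utterance_id, first.text)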
torchrir/datasets/template.py
ADDED
@@ -0,0 +1,65 @@
+from __future__ import annotations
+
+"""Dataset template for future extensions.
+
+Work in progress:
+    This module is a placeholder for future dataset integrations. The goal is
+    to provide a consistent interface for downloading, caching, enumerating
+    speakers/utterances, and loading audio in a reproducible way.
+"""
+
+from dataclasses import dataclass
+from pathlib import Path
+from typing import List, Sequence, Tuple
+
+import torch
+
+from .base import BaseDataset, SentenceLike
+
+
+@dataclass
+class TemplateSentence:
+    """Minimal sentence metadata for a template dataset."""
+
+    utterance_id: str
+    text: str
+
+
+class TemplateDataset(BaseDataset):
+    """Work-in-progress template dataset implementation.
+
+    Goal:
+        Implement concrete dataset handlers by filling in download logic,
+        metadata parsing, and audio loading while keeping the BaseDataset
+        protocol intact.
+    """
+
+    def __init__(self, root: Path, speaker: str = "default", download: bool = False) -> None:
+        self.root = Path(root)
+        self.speaker = speaker
+        if download:
+            raise NotImplementedError(
+                "download is not implemented yet. Intended to fetch and cache "
+                "dataset archives under root."
+            )
+
+    def list_speakers(self) -> List[str]:
+        """Return available speaker IDs."""
+        return ["default"]
+
+    def available_sentences(self) -> Sequence[SentenceLike]:
+        """Return sentence entries that have audio available.
+
+        Work in progress:
+            Intended to parse dataset metadata and filter to utterances that
+            have corresponding audio files on disk.
+        """
+        raise NotImplementedError("available_sentences is not implemented yet")
+
+    def load_wav(self, utterance_id: str) -> Tuple[torch.Tensor, int]:
+        """Load audio for an utterance and return (audio, sample_rate).
+
+        Work in progress:
+            Intended to load audio from local cache and return mono float32.
+        """
+        raise NotImplementedError("load_wav is not implemented yet")
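A sketch of how a concrete handler might fill in this template while keeping the BaseDataset protocol intact (everything below is illustrative, not part of the package):

    from typing import Sequence, Tuple

    import torch

    from torchrir.datasets import TemplateDataset, TemplateSentence
    from torchrir.datasets.cmu_arctic import load_wav_mono

    class FolderDataset(TemplateDataset):
        # Treats every wav directly under root as one utterance of one speaker.
        def available_sentences(self) -> Sequence[TemplateSentence]:
            return [TemplateSentence(p.stem, "") for p in sorted(self.root.glob("*.wav"))]

        def load_wav(self, utterance_id: str) -> Tuple[torch.Tensor, int]:
            return load_wav_mono(self.root / f"{utterance_id}.wav")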
torchrir/datasets/utils.py
ADDED
@@ -0,0 +1,74 @@
+from __future__ import annotations
+
+"""Dataset-agnostic utilities."""
+
+import random
+from typing import Callable, List, Optional, Sequence, Tuple
+
+import torch
+
+from .base import BaseDataset, SentenceLike
+
+
+def choose_speakers(dataset: BaseDataset, num_sources: int, rng: random.Random) -> List[str]:
+    """Select unique speakers for the requested number of sources."""
+    speakers = dataset.list_speakers()
+    if not speakers:
+        raise RuntimeError("no speakers available")
+    if num_sources > len(speakers):
+        raise ValueError(f"num_sources must be <= {len(speakers)} for unique speakers")
+    return rng.sample(speakers, num_sources)
+
+
+def load_dataset_sources(
+    *,
+    dataset_factory: Callable[[Optional[str]], BaseDataset],
+    num_sources: int,
+    duration_s: float,
+    rng: random.Random,
+) -> Tuple[torch.Tensor, int, List[Tuple[str, List[str]]]]:
+    """Load and concatenate utterances for each speaker into fixed-length signals."""
+    dataset0 = dataset_factory(None)
+    speakers = choose_speakers(dataset0, num_sources, rng)
+    signals: List[torch.Tensor] = []
+    info: List[Tuple[str, List[str]]] = []
+    fs: int | None = None
+    target_samples: int | None = None
+
+    for speaker in speakers:
+        dataset = dataset_factory(speaker)
+        sentences: Sequence[SentenceLike] = dataset.available_sentences()
+        if not sentences:
+            raise RuntimeError(f"no sentences found for speaker {speaker}")
+
+        utterance_ids: List[str] = []
+        segments: List[torch.Tensor] = []
+        total = 0
+        sentences = list(sentences)
+        rng.shuffle(sentences)
+        idx = 0
+
+        while target_samples is None or total < target_samples:
+            if idx >= len(sentences):
+                rng.shuffle(sentences)
+                idx = 0
+            sentence = sentences[idx]
+            idx += 1
+            audio, sample_rate = dataset.load_wav(sentence.utterance_id)
+            if fs is None:
+                fs = sample_rate
+                target_samples = int(duration_s * fs)
+            elif sample_rate != fs:
+                raise ValueError(
+                    f"sample rate mismatch: expected {fs}, got {sample_rate} for {speaker}"
+                )
+            segments.append(audio)
+            utterance_ids.append(sentence.utterance_id)
+            total += audio.numel()
+
+        signal = torch.cat(segments, dim=0)[:target_samples]
+        signals.append(signal)
+        info.append((speaker, utterance_ids))
+
+    stacked = torch.stack(signals, dim=0)
+    return stacked, int(fs), info
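A usage sketch pairing these helpers with the CMU ARCTIC handler (path and seed are arbitrary; note the factory must accept None, which load_dataset_sources passes once to enumerate speakers):

    import random
    from pathlib import Path

    from torchrir.datasets import CmuArcticDataset, load_dataset_sources

    def factory(speaker):
        return CmuArcticDataset(Path("data"), speaker=speaker or "bdl", download=True)

    signals, fs, info = load_dataset_sources(
        dataset_factory=factory, num_sources=2, duration_s=5.0, rng=random.Random(0)
    )
    # signals: (2, int(5.0 * fs)) tensor; info: [(speaker, [utterance_id, ...]), ...]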
torchrir/directivity.py
ADDED
@@ -0,0 +1,33 @@
+from __future__ import annotations
+
+"""Directivity pattern utilities."""
+
+import torch
+from torch import Tensor
+
+
+def directivity_gain(pattern: str, cos_theta: Tensor) -> Tensor:
+    """Compute directivity gain for a pattern given cos(theta)."""
+    pattern = pattern.lower()
+    if pattern in ("omni", "omnidirectional"):
+        return torch.ones_like(cos_theta)
+    if pattern in ("homni", "halfomni", "half-omni"):
+        return (cos_theta > 0).to(cos_theta.dtype)
+    if pattern in ("subcardioid", "subcard"):
+        return 0.75 + 0.25 * cos_theta
+    if pattern in ("cardioid", "card"):
+        return 0.5 + 0.5 * cos_theta
+    if pattern in ("hypercardioid", "hypcard"):
+        return 0.25 + 0.75 * cos_theta
+    if pattern in ("bidir", "bidirectional", "figure8", "figure-8"):
+        return cos_theta
+    raise ValueError(f"unsupported directivity pattern: {pattern}")
+
+
+def split_directivity(directivity: str | tuple[str, str]) -> tuple[str, str]:
+    """Normalize directivity specification into (source, mic)."""
+    if isinstance(directivity, (list, tuple)):
+        if len(directivity) != 2:
+            raise ValueError("directivity tuple must have length 2")
+        return directivity[0], directivity[1]
+    return directivity, directivity
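Apart from the hard half-space "homni" mask, the patterns above are the first-order family g(theta) = alpha + (1 - alpha) * cos(theta) with alpha in {1, 0.75, 0.5, 0.25, 0}. A quick check:

    import torch
    from torchrir.directivity import directivity_gain, split_directivity

    cos_theta = torch.tensor([1.0, 0.0, -1.0])      # on-axis, broadside, rear
    print(directivity_gain("cardioid", cos_theta))  # tensor([1.0000, 0.5000, 0.0000])
    print(split_directivity("omni"))                # ('omni', 'omni')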
torchrir/dynamic.py
ADDED
@@ -0,0 +1,60 @@
+from __future__ import annotations
+
+"""Dynamic convolution utilities.
+
+DynamicConvolver is the public API for time-varying convolution. Lower-level
+helpers live in signal.py and are not part of the stable surface.
+"""
+
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+from torch import Tensor
+
+from .signal import _ensure_dynamic_rirs, _ensure_signal
+
+
+@dataclass(frozen=True)
+class DynamicConvolver:
+    """Convolver for time-varying RIRs."""
+
+    mode: str = "trajectory"
+    hop: Optional[int] = None
+    timestamps: Optional[Tensor] = None
+    fs: Optional[float] = None
+
+    def __call__(self, signal: Tensor, rirs: Tensor) -> Tensor:
+        return self.convolve(signal, rirs)
+
+    def convolve(self, signal: Tensor, rirs: Tensor) -> Tensor:
+        """Convolve signals with time-varying RIRs."""
+        if self.mode not in ("trajectory", "hop"):
+            raise ValueError("mode must be 'trajectory' or 'hop'")
+        if self.mode == "hop":
+            if self.hop is None:
+                raise ValueError("hop must be provided for hop mode")
+            return _convolve_dynamic_hop(signal, rirs, self.hop)
+        return _convolve_dynamic_trajectory(signal, rirs, timestamps=self.timestamps, fs=self.fs)
+
+
+def _convolve_dynamic_hop(signal: Tensor, rirs: Tensor, hop: int) -> Tensor:
+    from .signal import _convolve_dynamic_rir_hop
+
+    signal = _ensure_signal(signal)
+    rirs = _ensure_dynamic_rirs(rirs, signal)
+    return _convolve_dynamic_rir_hop(signal, rirs, hop)
+
+
+def _convolve_dynamic_trajectory(
+    signal: Tensor,
+    rirs: Tensor,
+    *,
+    timestamps: Optional[Tensor],
+    fs: Optional[float],
+) -> Tensor:
+    from .signal import _convolve_dynamic_rir_trajectory
+
+    signal = _ensure_signal(signal)
+    rirs = _ensure_dynamic_rirs(rirs, signal)
+    return _convolve_dynamic_rir_trajectory(signal, rirs, timestamps=timestamps, fs=fs)
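A construction sketch; the exact tensor layouts accepted for signal and rirs are validated by the private helpers in signal.py and are assumed rather than shown here:

    from torchrir.dynamic import DynamicConvolver

    conv = DynamicConvolver(mode="hop", hop=256)   # switch RIRs every 256 samples
    # wet = conv(dry, rirs)  # rirs: one RIR per hop frame, per signal.py's contract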
torchrir/logging_utils.py
ADDED
@@ -0,0 +1,55 @@
+from __future__ import annotations
+
+"""Logging helpers for torchrir."""
+
+from dataclasses import dataclass, replace
+import logging
+from typing import Optional
+
+
+@dataclass(frozen=True)
+class LoggingConfig:
+    """Configuration for torchrir logging."""
+
+    level: str | int = "INFO"
+    format: str = "%(levelname)s:%(name)s:%(message)s"
+    datefmt: Optional[str] = None
+    propagate: bool = False
+
+    def resolve_level(self) -> int:
+        """Resolve level to a logging integer constant."""
+        if isinstance(self.level, int):
+            return self.level
+        if not isinstance(self.level, str):
+            raise TypeError("level must be str or int")
+        name = self.level.upper()
+        if name not in logging._nameToLevel:
+            raise ValueError(f"unknown log level: {self.level}")
+        return logging._nameToLevel[name]
+
+    def replace(self, **kwargs) -> "LoggingConfig":
+        """Return a new config with updated fields."""
+        return replace(self, **kwargs)
+
+
+def setup_logging(config: LoggingConfig, *, name: str = "torchrir") -> logging.Logger:
+    """Configure and return the base torchrir logger."""
+    logger = logging.getLogger(name)
+    level = config.resolve_level()
+    logger.setLevel(level)
+    logger.propagate = config.propagate
+    if not logger.handlers:
+        handler = logging.StreamHandler()
+        handler.setLevel(level)
+        handler.setFormatter(logging.Formatter(config.format, datefmt=config.datefmt))
+        logger.addHandler(handler)
+    return logger
+
+
+def get_logger(name: Optional[str] = None) -> logging.Logger:
+    """Return a torchrir logger, namespaced under the torchrir root."""
+    if not name:
+        return logging.getLogger("torchrir")
+    if name.startswith("torchrir"):
+        return logging.getLogger(name)
+    return logging.getLogger(f"torchrir.{name}")
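A usage sketch of the logging surface above:

    from torchrir.logging_utils import LoggingConfig, get_logger, setup_logging

    setup_logging(LoggingConfig(level="DEBUG"))
    log = get_logger("datasets")   # same as logging.getLogger("torchrir.datasets")
    log.debug("configured")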