torchrir 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchrir/__init__.py +85 -0
- torchrir/config.py +59 -0
- torchrir/core.py +741 -0
- torchrir/datasets/__init__.py +27 -0
- torchrir/datasets/base.py +27 -0
- torchrir/datasets/cmu_arctic.py +204 -0
- torchrir/datasets/template.py +65 -0
- torchrir/datasets/utils.py +74 -0
- torchrir/directivity.py +33 -0
- torchrir/dynamic.py +60 -0
- torchrir/logging_utils.py +55 -0
- torchrir/plotting.py +210 -0
- torchrir/plotting_utils.py +173 -0
- torchrir/results.py +22 -0
- torchrir/room.py +150 -0
- torchrir/scene.py +67 -0
- torchrir/scene_utils.py +51 -0
- torchrir/signal.py +233 -0
- torchrir/simulators.py +86 -0
- torchrir/utils.py +281 -0
- torchrir-0.1.0.dist-info/METADATA +213 -0
- torchrir-0.1.0.dist-info/RECORD +26 -0
- torchrir-0.1.0.dist-info/WHEEL +5 -0
- torchrir-0.1.0.dist-info/licenses/LICENSE +190 -0
- torchrir-0.1.0.dist-info/licenses/NOTICE +4 -0
- torchrir-0.1.0.dist-info/top_level.txt +1 -0
torchrir/__init__.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
"""TorchRIR public API."""
|
|
2
|
+
|
|
3
|
+
from .config import SimulationConfig, default_config
|
|
4
|
+
from .core import simulate_dynamic_rir, simulate_rir
|
|
5
|
+
from .dynamic import DynamicConvolver
|
|
6
|
+
from .logging_utils import LoggingConfig, get_logger, setup_logging
|
|
7
|
+
from .plotting import plot_scene_dynamic, plot_scene_static
|
|
8
|
+
from .plotting_utils import plot_scene_and_save
|
|
9
|
+
from .room import MicrophoneArray, Room, Source
|
|
10
|
+
from .scene import Scene
|
|
11
|
+
from .results import RIRResult
|
|
12
|
+
from .simulators import FDTDSimulator, ISMSimulator, RIRSimulator, RayTracingSimulator
|
|
13
|
+
from .signal import convolve_rir, fft_convolve
|
|
14
|
+
from .datasets import (
|
|
15
|
+
BaseDataset,
|
|
16
|
+
CmuArcticDataset,
|
|
17
|
+
CmuArcticSentence,
|
|
18
|
+
choose_speakers,
|
|
19
|
+
list_cmu_arctic_speakers,
|
|
20
|
+
SentenceLike,
|
|
21
|
+
load_dataset_sources,
|
|
22
|
+
TemplateDataset,
|
|
23
|
+
TemplateSentence,
|
|
24
|
+
load_wav_mono,
|
|
25
|
+
save_wav,
|
|
26
|
+
)
|
|
27
|
+
from .scene_utils import binaural_mic_positions, clamp_positions, linear_trajectory, sample_positions
|
|
28
|
+
from .utils import (
|
|
29
|
+
att2t_SabineEstimation,
|
|
30
|
+
att2t_sabine_estimation,
|
|
31
|
+
beta_SabineEstimation,
|
|
32
|
+
DeviceSpec,
|
|
33
|
+
estimate_beta_from_t60,
|
|
34
|
+
estimate_t60_from_beta,
|
|
35
|
+
resolve_device,
|
|
36
|
+
t2n,
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
# Public API of the torchrir package. This list is what
# `from torchrir import *` binds; every entry corresponds to a name imported
# from a submodule at the top of this file, so keep the two in sync when
# adding or removing exports.
__all__ = [
    "MicrophoneArray",
    "Room",
    "Source",
    "RIRResult",
    "RIRSimulator",
    "ISMSimulator",
    "RayTracingSimulator",
    "FDTDSimulator",
    "convolve_rir",
    "att2t_SabineEstimation",
    "att2t_sabine_estimation",
    "beta_SabineEstimation",
    "DeviceSpec",
    "BaseDataset",
    "CmuArcticDataset",
    "CmuArcticSentence",
    "choose_speakers",
    "DynamicConvolver",
    "estimate_beta_from_t60",
    "estimate_t60_from_beta",
    "fft_convolve",
    "get_logger",
    "list_cmu_arctic_speakers",
    "LoggingConfig",
    "resolve_device",
    "SentenceLike",
    "load_dataset_sources",
    "load_wav_mono",
    "TemplateDataset",
    "TemplateSentence",
    "binaural_mic_positions",
    "clamp_positions",
    "linear_trajectory",
    "sample_positions",
    "plot_scene_dynamic",
    "plot_scene_and_save",
    "plot_scene_static",
    "save_wav",
    "Scene",
    "setup_logging",
    "SimulationConfig",
    "default_config",
    "simulate_dynamic_rir",
    "simulate_rir",
    "t2n",
]
|
torchrir/config.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
"""Simulation configuration for torchrir."""
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, replace
|
|
6
|
+
from typing import Optional
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass(frozen=True)
class SimulationConfig:
    """Configuration values for RIR simulation and convolution.

    Instances are immutable (``frozen=True``). Constructing an instance does
    not validate it: call :meth:`validate` explicitly, or obtain instances via
    :meth:`replace` or ``default_config()``, both of which validate.

    Optional fields left as ``None`` are not set by this config and are skipped
    by :meth:`validate`. How unset values are resolved is decided by the code
    consuming the config, not here.
    """

    fs: Optional[float] = None  # must be > 0 when set (presumably sample rate in Hz)
    max_order: Optional[int] = None  # must be >= 0 when set
    tmax: Optional[float] = None  # must be > 0 when set
    directivity: Optional[str | tuple[str, str]] = None  # not validated here
    device: Optional[torch.device | str] = None  # not validated here
    seed: Optional[int] = None  # must be >= 0 when set
    use_lut: bool = True
    mixed_precision: bool = False
    frac_delay_length: int = 81  # must be a positive odd integer
    sinc_lut_granularity: int = 20  # must be > 0
    image_chunk_size: int = 2048  # must be > 0
    accumulate_chunk_size: int = 4096  # must be > 0
    use_compile: bool = False

    def validate(self) -> None:
        """Validate configuration values.

        Every check is written as ``not (value > 0)`` / ``not (value >= 0)``
        rather than ``value <= 0`` / ``value < 0``. Any comparison with NaN is
        False, so the negated form rejects NaN, which would otherwise pass.

        Raises:
            ValueError: if any set field is out of its allowed range.
        """
        # Optional fields are only checked when they have been set.
        if self.fs is not None and not (self.fs > 0):
            raise ValueError("fs must be positive")
        if self.max_order is not None and not (self.max_order >= 0):
            raise ValueError("max_order must be non-negative")
        if self.tmax is not None and not (self.tmax > 0):
            raise ValueError("tmax must be positive")
        if self.seed is not None and not (self.seed >= 0):
            raise ValueError("seed must be non-negative")
        # NOTE(review): odd length presumably gives the fractional-delay kernel
        # a single centre tap; confirm against the kernel construction.
        if not (self.frac_delay_length > 0) or self.frac_delay_length % 2 == 0:
            raise ValueError("frac_delay_length must be a positive odd integer")
        if not (self.sinc_lut_granularity > 0):
            raise ValueError("sinc_lut_granularity must be positive")
        if not (self.image_chunk_size > 0):
            raise ValueError("image_chunk_size must be positive")
        if not (self.accumulate_chunk_size > 0):
            raise ValueError("accumulate_chunk_size must be positive")

    def replace(self, **kwargs) -> "SimulationConfig":
        """Return a new config with updated fields.

        The original instance is left unchanged (the dataclass is frozen).

        Args:
            **kwargs: Field names and their new values.

        Returns:
            A new, validated ``SimulationConfig``.

        Raises:
            TypeError: if a keyword is not a field of this dataclass
                (raised by ``dataclasses.replace``).
            ValueError: if the resulting config fails :meth:`validate`.
        """
        # Call dataclasses.replace through its module. The bare name `replace`
        # resolves to the module-level import, which reads like a recursive
        # call to this method.
        import dataclasses

        new_cfg = dataclasses.replace(self, **kwargs)
        new_cfg.validate()
        return new_cfg
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def default_config() -> SimulationConfig:
    """Build a ``SimulationConfig`` populated purely from its field defaults.

    The fresh instance is run through ``validate()`` before being handed back,
    so a ``ValueError`` surfaces here if the defaults are ever edited into an
    invalid state.

    Returns:
        A newly constructed, validated ``SimulationConfig``.
    """
    defaults = SimulationConfig()
    # Construction alone does not validate, so check explicitly before returning.
    defaults.validate()
    return defaults
|