torchrir 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
torchrir/__init__.py ADDED
@@ -0,0 +1,85 @@
1
+ """TorchRIR public API."""
2
+
3
+ from .config import SimulationConfig, default_config
4
+ from .core import simulate_dynamic_rir, simulate_rir
5
+ from .dynamic import DynamicConvolver
6
+ from .logging_utils import LoggingConfig, get_logger, setup_logging
7
+ from .plotting import plot_scene_dynamic, plot_scene_static
8
+ from .plotting_utils import plot_scene_and_save
9
+ from .room import MicrophoneArray, Room, Source
10
+ from .scene import Scene
11
+ from .results import RIRResult
12
+ from .simulators import FDTDSimulator, ISMSimulator, RIRSimulator, RayTracingSimulator
13
+ from .signal import convolve_rir, fft_convolve
14
+ from .datasets import (
15
+ BaseDataset,
16
+ CmuArcticDataset,
17
+ CmuArcticSentence,
18
+ choose_speakers,
19
+ list_cmu_arctic_speakers,
20
+ SentenceLike,
21
+ load_dataset_sources,
22
+ TemplateDataset,
23
+ TemplateSentence,
24
+ load_wav_mono,
25
+ save_wav,
26
+ )
27
+ from .scene_utils import binaural_mic_positions, clamp_positions, linear_trajectory, sample_positions
28
+ from .utils import (
29
+ att2t_SabineEstimation,
30
+ att2t_sabine_estimation,
31
+ beta_SabineEstimation,
32
+ DeviceSpec,
33
+ estimate_beta_from_t60,
34
+ estimate_t60_from_beta,
35
+ resolve_device,
36
+ t2n,
37
+ )
38
+
39
# Public names re-exported by ``from torchrir import *``. Entries are grouped
# several per line for readability, but the sequence is identical to the
# original listing, so ``torchrir.__all__`` compares equal element-for-element.
__all__ = [
    "MicrophoneArray", "Room", "Source", "RIRResult",
    "RIRSimulator", "ISMSimulator", "RayTracingSimulator", "FDTDSimulator",
    "convolve_rir",
    "att2t_SabineEstimation", "att2t_sabine_estimation", "beta_SabineEstimation",
    "DeviceSpec",
    "BaseDataset", "CmuArcticDataset", "CmuArcticSentence", "choose_speakers",
    "DynamicConvolver",
    "estimate_beta_from_t60", "estimate_t60_from_beta",
    "fft_convolve",
    "get_logger",
    "list_cmu_arctic_speakers",
    "LoggingConfig",
    "resolve_device",
    "SentenceLike",
    "load_dataset_sources", "load_wav_mono",
    "TemplateDataset", "TemplateSentence",
    "binaural_mic_positions", "clamp_positions", "linear_trajectory", "sample_positions",
    "plot_scene_dynamic", "plot_scene_and_save", "plot_scene_static",
    "save_wav",
    "Scene",
    "setup_logging",
    "SimulationConfig", "default_config",
    "simulate_dynamic_rir", "simulate_rir",
    "t2n",
]
torchrir/config.py ADDED
@@ -0,0 +1,59 @@
1
+ from __future__ import annotations
2
+
3
+ """Simulation configuration for torchrir."""
4
+
5
+ from dataclasses import dataclass, replace
6
+ from typing import Optional
7
+
8
+ import torch
9
+
10
+
11
@dataclass(frozen=True)
class SimulationConfig:
    """Configuration values for RIR simulation and convolution.

    Instances are immutable (``frozen=True``). Constructing an instance does
    not check its values. Call :meth:`validate` explicitly, or derive a new
    instance with :meth:`replace`, which validates the result.

    Optional fields default to ``None``. How ``None`` is interpreted is
    decided by the code that consumes the config, not by this class.
    """

    fs: Optional[float] = None
    max_order: Optional[int] = None
    tmax: Optional[float] = None
    directivity: Optional[str | tuple[str, str]] = None
    device: Optional[torch.device | str] = None
    seed: Optional[int] = None
    use_lut: bool = True
    mixed_precision: bool = False
    frac_delay_length: int = 81
    sinc_lut_granularity: int = 20
    image_chunk_size: int = 2048
    accumulate_chunk_size: int = 4096
    use_compile: bool = False

    def validate(self) -> None:
        """Validate configuration values.

        Checks are phrased as ``not (x > 0)`` / ``not (x >= 0)`` instead of
        ``x <= 0`` / ``x < 0``. NaN compares False against everything, so the
        latter form would let NaN slip through as "valid".

        Raises:
            ValueError: If ``fs`` or ``tmax`` is set but not positive; if
                ``max_order`` or ``seed`` is set but negative; if
                ``frac_delay_length`` is not a positive odd integer; or if
                ``sinc_lut_granularity``, ``image_chunk_size`` or
                ``accumulate_chunk_size`` is not positive.
        """
        if self.fs is not None and not (self.fs > 0):
            raise ValueError("fs must be positive")
        if self.max_order is not None and not (self.max_order >= 0):
            raise ValueError("max_order must be non-negative")
        if self.tmax is not None and not (self.tmax > 0):
            raise ValueError("tmax must be positive")
        if self.seed is not None and not (self.seed >= 0):
            raise ValueError("seed must be non-negative")
        # ``% 2 != 1`` (rather than ``% 2 == 0``) also rejects non-integral
        # values such as 81.5 and infinities (inf % 2 is NaN).
        if not (self.frac_delay_length > 0) or self.frac_delay_length % 2 != 1:
            raise ValueError("frac_delay_length must be a positive odd integer")
        if not (self.sinc_lut_granularity > 0):
            raise ValueError("sinc_lut_granularity must be positive")
        if not (self.image_chunk_size > 0):
            raise ValueError("image_chunk_size must be positive")
        if not (self.accumulate_chunk_size > 0):
            raise ValueError("accumulate_chunk_size must be positive")

    def replace(self, **kwargs) -> "SimulationConfig":
        """Return a new, validated config with the given fields updated.

        The current instance is left unchanged.

        Raises:
            TypeError: If ``kwargs`` names a field the dataclass does not
                accept (raised by the generated ``__init__``).
            ValueError: If the resulting config fails :meth:`validate`.
        """
        # ``replace`` here resolves to the module-level ``dataclasses.replace``
        # import, not to this method: names bound in a class body are not
        # visible from inside its methods.
        new_cfg = replace(self, **kwargs)
        new_cfg.validate()
        return new_cfg
53
+
54
+
55
def default_config() -> SimulationConfig:
    """Build a ``SimulationConfig`` that uses every field's declared default.

    The instance is run through ``validate()`` before being returned, so a
    ``ValueError`` surfaces here if the declared defaults are ever edited
    into an invalid combination.
    """
    defaults = SimulationConfig()
    defaults.validate()  # fail fast on inconsistent built-in defaults
    return defaults