torchrir 0.1.0__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff shows the content changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
torchrir/__init__.py CHANGED
@@ -4,6 +4,8 @@ from .config import SimulationConfig, default_config
 from .core import simulate_dynamic_rir, simulate_rir
 from .dynamic import DynamicConvolver
 from .logging_utils import LoggingConfig, get_logger, setup_logging
+from .animation import animate_scene_gif
+from .metadata import build_metadata, save_metadata_json
 from .plotting import plot_scene_dynamic, plot_scene_static
 from .plotting_utils import plot_scene_and_save
 from .room import MicrophoneArray, Room, Source
@@ -24,7 +26,12 @@ from .datasets import (
     load_wav_mono,
     save_wav,
 )
-from .scene_utils import binaural_mic_positions, clamp_positions, linear_trajectory, sample_positions
+from .scene_utils import (
+    binaural_mic_positions,
+    clamp_positions,
+    linear_trajectory,
+    sample_positions,
+)
 from .utils import (
     att2t_SabineEstimation,
     att2t_sabine_estimation,
@@ -61,6 +68,8 @@ __all__ = [
     "get_logger",
     "list_cmu_arctic_speakers",
     "LoggingConfig",
+    "animate_scene_gif",
+    "build_metadata",
     "resolve_device",
     "SentenceLike",
     "load_dataset_sources",
@@ -75,6 +84,7 @@ __all__ = [
     "plot_scene_and_save",
     "plot_scene_static",
     "save_wav",
+    "save_metadata_json",
     "Scene",
     "setup_logging",
     "SimulationConfig",
torchrir/animation.py ADDED
@@ -0,0 +1,175 @@
+from __future__ import annotations
+
+"""Animation helpers for dynamic scenes."""
+
+from pathlib import Path
+from typing import Optional, Sequence
+
+import torch
+
+from .plotting_utils import _positions_to_cpu, _to_cpu, _traj_steps, _trajectory_to_cpu
+
+
+def animate_scene_gif(
+    *,
+    out_path: Path,
+    room: Sequence[float] | torch.Tensor,
+    sources: object | torch.Tensor | Sequence,
+    mics: object | torch.Tensor | Sequence,
+    src_traj: Optional[torch.Tensor | Sequence] = None,
+    mic_traj: Optional[torch.Tensor | Sequence] = None,
+    step: int = 1,
+    fps: Optional[float] = None,
+    signal_len: Optional[int] = None,
+    fs: Optional[float] = None,
+    duration_s: Optional[float] = None,
+    plot_2d: bool = True,
+    plot_3d: bool = False,
+) -> Path:
+    """Render a GIF showing source/mic trajectories.
+
+    Args:
+        out_path: Destination GIF path.
+        room: Room size tensor or sequence.
+        sources: Source positions or Source-like object.
+        mics: Microphone positions or MicrophoneArray-like object.
+        src_traj: Optional source trajectory (T, n_src, dim).
+        mic_traj: Optional mic trajectory (T, n_mic, dim).
+        step: Subsampling step for trajectories.
+        fps: Frames per second for the GIF (auto if None).
+        signal_len: Optional signal length (samples) to infer elapsed time.
+        fs: Sample rate used with signal_len.
+        duration_s: Optional total duration in seconds (overrides signal_len/fs).
+        plot_2d: Use 2D projection if True.
+        plot_3d: Use 3D projection if True and dim == 3.
+
+    Returns:
+        The output path.
+
+    Example:
+        >>> animate_scene_gif(
+        ...     out_path=Path("outputs/scene.gif"),
+        ...     room=[6.0, 4.0, 3.0],
+        ...     sources=[[1.0, 2.0, 1.5]],
+        ...     mics=[[2.0, 2.0, 1.5]],
+        ...     src_traj=src_traj,
+        ...     mic_traj=mic_traj,
+        ...     signal_len=16000,
+        ...     fs=16000,
+        ... )
+    """
+    import matplotlib.pyplot as plt
+    from matplotlib import animation
+
+    out_path = Path(out_path)
+    out_path.parent.mkdir(parents=True, exist_ok=True)
+
+    room_size = _to_cpu(room)
+    src_pos = _positions_to_cpu(sources)
+    mic_pos = _positions_to_cpu(mics)
+    dim = int(room_size.numel())
+    view_dim = 3 if (plot_3d and dim == 3) else 2
+    view_room = room_size[:view_dim]
+    view_src = src_pos[:, :view_dim]
+    view_mic = mic_pos[:, :view_dim]
+
+    if src_traj is None and mic_traj is None:
+        raise ValueError("at least one trajectory is required for animation")
+    steps = _traj_steps(src_traj, mic_traj)
+    src_traj = _trajectory_to_cpu(src_traj, src_pos, steps)
+    mic_traj = _trajectory_to_cpu(mic_traj, mic_pos, steps)
+    view_src_traj = src_traj[:, :, :view_dim]
+    view_mic_traj = mic_traj[:, :, :view_dim]
+
+    if view_dim == 3:
+        fig = plt.figure()
+        ax = fig.add_subplot(111, projection="3d")
+        ax.set_xlim(0, view_room[0].item())
+        ax.set_ylim(0, view_room[1].item())
+        ax.set_zlim(0, view_room[2].item())
+        ax.set_xlabel("x")
+        ax.set_ylabel("y")
+        ax.set_zlabel("z")
+    else:
+        fig, ax = plt.subplots()
+        ax.set_xlim(0, view_room[0].item())
+        ax.set_ylim(0, view_room[1].item())
+        ax.set_aspect("equal", adjustable="box")
+        ax.set_xlabel("x")
+        ax.set_ylabel("y")
+
+    src_scatter = ax.scatter([], [], marker="^", color="tab:green", label="sources")
+    mic_scatter = ax.scatter([], [], marker="o", color="tab:orange", label="mics")
+    src_lines = []
+    mic_lines = []
+    for _ in range(view_src_traj.shape[1]):
+        if view_dim == 2:
+            (line,) = ax.plot([], [], color="tab:green", alpha=0.6)
+        else:
+            (line,) = ax.plot([], [], [], color="tab:green", alpha=0.6)
+        src_lines.append(line)
+    for _ in range(view_mic_traj.shape[1]):
+        if view_dim == 2:
+            (line,) = ax.plot([], [], color="tab:orange", alpha=0.6)
+        else:
+            (line,) = ax.plot([], [], [], color="tab:orange", alpha=0.6)
+        mic_lines.append(line)
+
+    ax.legend(loc="best")
+
+    if duration_s is None and signal_len is not None and fs is not None:
+        duration_s = float(signal_len) / float(fs)
+
+    def _frame(i: int):
+        idx = min(i * step, view_src_traj.shape[0] - 1)
+        src_frame = view_src_traj[: idx + 1]
+        mic_frame = view_mic_traj[: idx + 1]
+        src_pos_frame = view_src_traj[idx]
+        mic_pos_frame = view_mic_traj[idx]
+
+        if view_dim == 2:
+            src_scatter.set_offsets(src_pos_frame)
+            mic_scatter.set_offsets(mic_pos_frame)
+            for s_idx, line in enumerate(src_lines):
+                xy = src_frame[:, s_idx, :]
+                line.set_data(xy[:, 0], xy[:, 1])
+            for m_idx, line in enumerate(mic_lines):
+                xy = mic_frame[:, m_idx, :]
+                line.set_data(xy[:, 0], xy[:, 1])
+        else:
+            setattr(
+                src_scatter,
+                "_offsets3d",
+                (src_pos_frame[:, 0], src_pos_frame[:, 1], src_pos_frame[:, 2]),
+            )
+            setattr(
+                mic_scatter,
+                "_offsets3d",
+                (mic_pos_frame[:, 0], mic_pos_frame[:, 1], mic_pos_frame[:, 2]),
+            )
+            for s_idx, line in enumerate(src_lines):
+                xyz = src_frame[:, s_idx, :]
+                line.set_data(xyz[:, 0], xyz[:, 1])
+                line.set_3d_properties(xyz[:, 2])
+            for m_idx, line in enumerate(mic_lines):
+                xyz = mic_frame[:, m_idx, :]
+                line.set_data(xyz[:, 0], xyz[:, 1])
+                line.set_3d_properties(xyz[:, 2])
+        if duration_s is not None and steps > 1:
+            t = (idx / (steps - 1)) * duration_s
+            ax.set_title(f"t = {t:.2f} s")
+        return [src_scatter, mic_scatter, *src_lines, *mic_lines]
+
+    frames = max(1, (view_src_traj.shape[0] + step - 1) // step)
+    if fps is None or fps <= 0:
+        if duration_s is not None and duration_s > 0:
+            fps = frames / duration_s
+        else:
+            fps = 6.0
+    anim = animation.FuncAnimation(
+        fig, _frame, frames=frames, interval=1000 / fps, blit=False
+    )
+    fps_int = None if fps is None else max(1, int(round(fps)))
+    anim.save(out_path, writer="pillow", fps=fps_int)
+    plt.close(fig)
+    return out_path
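
For context, a usage sketch assembled from the signature above; it assumes `linear_trajectory` returns a `(T, dim)` tensor, as suggested by the `simulate_dynamic_rir` docstring example added in this release:

    from pathlib import Path

    import torch

    from torchrir import animate_scene_gif, linear_trajectory

    T = 24
    # (T, 1, 3): one source moving along a line; trajectories are (T, n, dim).
    src_traj = linear_trajectory(
        torch.tensor([1.0, 2.0, 1.5]), torch.tensor([5.0, 2.0, 1.5]), T
    ).unsqueeze(1)
    # (T, 1, 3): one static mic, repeated per frame.
    mic_traj = torch.tensor([[2.0, 2.0, 1.5]]).unsqueeze(0).repeat(T, 1, 1)

    animate_scene_gif(
        out_path=Path("outputs/scene.gif"),
        room=[6.0, 4.0, 3.0],
        sources=src_traj[0],   # initial positions
        mics=mic_traj[0],
        src_traj=src_traj,
        mic_traj=mic_traj,
        duration_s=3.0,        # drives both the auto fps and the "t = ..." titles
    )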
torchrir/config.py CHANGED
@@ -10,7 +10,12 @@ import torch

 @dataclass(frozen=True)
 class SimulationConfig:
-    """Configuration values for RIR simulation and convolution."""
+    """Configuration values for RIR simulation and convolution.
+
+    Example:
+        >>> cfg = SimulationConfig(max_order=6, tmax=0.3, device="auto")
+        >>> cfg.validate()
+    """

     fs: Optional[float] = None
     max_order: Optional[int] = None
@@ -53,7 +58,11 @@ class SimulationConfig:


 def default_config() -> SimulationConfig:
-    """Return the default simulation configuration."""
+    """Return the default simulation configuration.
+
+    Example:
+        >>> cfg = default_config()
+    """
     cfg = SimulationConfig()
     cfg.validate()
     return cfg
torchrir/core.py CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
 """Core RIR simulation functions (static and dynamic)."""

 import math
+from collections.abc import Callable
 from typing import Optional, Tuple

 import torch
@@ -58,6 +59,18 @@ def simulate_rir(

     Returns:
         Tensor of shape (n_src, n_mic, nsample).
+
+    Example:
+        >>> room = Room.shoebox(size=[6.0, 4.0, 3.0], fs=16000, beta=[0.9] * 6)
+        >>> sources = Source.from_positions([[1.0, 2.0, 1.5]])
+        >>> mics = MicrophoneArray.from_positions([[2.0, 2.0, 1.5]])
+        >>> rir = simulate_rir(
+        ...     room=room,
+        ...     sources=sources,
+        ...     mics=mics,
+        ...     max_order=6,
+        ...     tmax=0.3,
+        ... )
     """
     cfg = config or default_config()
     cfg.validate()
@@ -78,9 +91,9 @@

     if not isinstance(room, Room):
         raise TypeError("room must be a Room instance")
-    if nsample is None and tmax is None:
-        raise ValueError("nsample or tmax must be provided")
     if nsample is None:
+        if tmax is None:
+            raise ValueError("nsample or tmax must be provided")
         nsample = int(math.ceil(tmax * room.fs))
     if nsample <= 0:
         raise ValueError("nsample must be positive")
@@ -208,6 +221,24 @@ def simulate_dynamic_rir(

     Returns:
         Tensor of shape (T, n_src, n_mic, nsample).
+
+    Example:
+        >>> room = Room.shoebox(size=[6.0, 4.0, 3.0], fs=16000, beta=[0.9] * 6)
+        >>> from torchrir import linear_trajectory
+        >>> src_traj = torch.stack(
+        ...     [linear_trajectory(torch.tensor([1.0, 2.0, 1.5]),
+        ...                        torch.tensor([4.0, 2.0, 1.5]), 8)],
+        ...     dim=1,
+        ... )
+        >>> mic_pos = torch.tensor([[2.0, 2.0, 1.5]])
+        >>> mic_traj = mic_pos.unsqueeze(0).repeat(8, 1, 1)
+        >>> rirs = simulate_dynamic_rir(
+        ...     room=room,
+        ...     src_traj=src_traj,
+        ...     mic_traj=mic_traj,
+        ...     max_order=4,
+        ...     tmax=0.3,
+        ... )
     """
     cfg = config or default_config()
     cfg.validate()
@@ -465,7 +496,11 @@
     if mic_pattern != "omni":
         if mic_dir is None:
             raise ValueError("mic orientation required for non-omni directivity")
-        mic_dir = mic_dir[None, :, None, :] if mic_dir.ndim == 2 else mic_dir.view(1, 1, 1, -1)
+        mic_dir = (
+            mic_dir[None, :, None, :]
+            if mic_dir.ndim == 2
+            else mic_dir.view(1, 1, 1, -1)
+        )
         cos_theta = _cos_between(-vec, mic_dir)
         gain = gain * directivity_gain(mic_pattern, cos_theta)

@@ -512,9 +547,9 @@
     if use_lut:
         sinc_lut = _get_sinc_lut(fdl, lut_gran, device=rir.device, dtype=dtype)

-    mic_offsets = (torch.arange(n_mic, device=rir.device, dtype=torch.int64) * nsample).view(
-        n_mic, 1, 1
-    )
+    mic_offsets = (
+        torch.arange(n_mic, device=rir.device, dtype=torch.int64) * nsample
+    ).view(n_mic, 1, 1)
     rir_flat = rir.view(-1)

     chunk_size = cfg.accumulate_chunk_size
@@ -529,7 +564,9 @@
             x_off_frac = (1.0 - frac_m) * lut_gran
             lut_gran_off = torch.floor(x_off_frac).to(torch.int64)
             x_off = x_off_frac - lut_gran_off.to(dtype)
-            lut_pos = lut_gran_off[..., None] + (n[None, None, :].to(torch.int64) * lut_gran)
+            lut_pos = lut_gran_off[..., None] + (
+                n[None, None, :].to(torch.int64) * lut_gran
+            )

             s0 = torch.take(sinc_lut, lut_pos)
             s1 = torch.take(sinc_lut, lut_pos + 1)
@@ -588,9 +625,9 @@
     if use_lut:
         sinc_lut = _get_sinc_lut(fdl, lut_gran, device=rir.device, dtype=sample.dtype)

-    sm_offsets = (torch.arange(n_sm, device=rir.device, dtype=torch.int64) * nsample).view(
-        n_sm, 1, 1
-    )
+    sm_offsets = (
+        torch.arange(n_sm, device=rir.device, dtype=torch.int64) * nsample
+    ).view(n_sm, 1, 1)
     rir_flat = rir.view(-1)

     n_img = idx0.shape[1]
@@ -604,7 +641,9 @@
             x_off_frac = (1.0 - frac_m) * lut_gran
             lut_gran_off = torch.floor(x_off_frac).to(torch.int64)
             x_off = x_off_frac - lut_gran_off.to(sample.dtype)
-            lut_pos = lut_gran_off[..., None] + (n[None, None, :].to(torch.int64) * lut_gran)
+            lut_pos = lut_gran_off[..., None] + (
+                n[None, None, :].to(torch.int64) * lut_gran
+            )

             s0 = torch.take(sinc_lut, lut_pos)
             s1 = torch.take(sinc_lut, lut_pos + 1)
@@ -630,12 +669,13 @@ _SINC_LUT_CACHE: dict[tuple[int, int, str, torch.dtype], Tensor] = {}
 _FDL_GRID_CACHE: dict[tuple[int, str, torch.dtype], Tensor] = {}
 _FDL_OFFSETS_CACHE: dict[tuple[int, str], Tensor] = {}
 _FDL_WINDOW_CACHE: dict[tuple[int, str, torch.dtype], Tensor] = {}
-_ACCUM_BATCH_COMPILED: dict[tuple[str, torch.dtype, int, int, bool, int], callable] = {}
+_AccumFn = Callable[[Tensor, Tensor, Tensor], None]
+_ACCUM_BATCH_COMPILED: dict[tuple[str, torch.dtype, int, int, bool, int], _AccumFn] = {}


 def _get_accumulate_fn(
     cfg: SimulationConfig, device: torch.device, dtype: torch.dtype
-) -> callable:
+) -> _AccumFn:
     """Return an accumulation function with config-bound constants."""
     use_lut = cfg.use_lut and device.type != "mps"
     fdl = cfg.frac_delay_length
@@ -691,7 +731,9 @@ def _get_fdl_window(fdl: int, *, device: torch.device, dtype: torch.dtype) -> Te
     return cached


-def _get_sinc_lut(fdl: int, lut_gran: int, *, device: torch.device, dtype: torch.dtype) -> Tensor:
+def _get_sinc_lut(
+    fdl: int, lut_gran: int, *, device: torch.device, dtype: torch.dtype
+) -> Tensor:
     """Create a sinc lookup table for fractional delays."""
     key = (fdl, lut_gran, str(device), dtype)
     cached = _SINC_LUT_CACHE.get(key)
@@ -735,7 +777,12 @@ def _apply_diffuse_tail(

     gen = torch.Generator(device=rir.device)
     gen.manual_seed(0 if seed is None else seed)
-    noise = torch.randn(rir[..., tdiff_idx:].shape, device=rir.device, dtype=rir.dtype, generator=gen)
-    scale = torch.linalg.norm(rir[..., tdiff_idx - 1 : tdiff_idx], dim=-1, keepdim=True) + 1e-8
+    noise = torch.randn(
+        rir[..., tdiff_idx:].shape, device=rir.device, dtype=rir.dtype, generator=gen
+    )
+    scale = (
+        torch.linalg.norm(rir[..., tdiff_idx - 1 : tdiff_idx], dim=-1, keepdim=True)
+        + 1e-8
+    )
     rir[..., tdiff_idx:] = noise * decay * scale
     return rir
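
One note on the `simulate_rir` hunks above: the `nsample`/`tmax` guard is now nested, so `tmax` is only consulted when `nsample` is absent; behavior is unchanged, but the new flow makes it explicit that `tmax` is only dereferenced after the None check. A standalone sketch of the same logic (extracted here purely for illustration; `resolve_nsample` is not a torchrir function):

    import math

    def resolve_nsample(nsample, tmax, fs):
        # Mirrors the reworked guard in the @@ -78,9 +91,9 @@ hunk.
        if nsample is None:
            if tmax is None:
                raise ValueError("nsample or tmax must be provided")
            nsample = int(math.ceil(tmax * fs))  # tmax=0.3 s at fs=16000 -> 4800
        if nsample <= 0:
            raise ValueError("nsample must be positive")
        return nsample

    assert resolve_nsample(None, 0.3, 16000) == 4800
    assert resolve_nsample(1024, None, 16000) == 1024  # tmax not consulted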
@@ -44,12 +44,22 @@ def list_cmu_arctic_speakers() -> List[str]:
 @dataclass
 class CmuArcticSentence:
     """Sentence metadata from CMU ARCTIC."""
+
     utterance_id: str
     text: str


 class CmuArcticDataset:
-    def __init__(self, root: Path, speaker: str = "bdl", download: bool = False) -> None:
+    """CMU ARCTIC dataset loader.
+
+    Example:
+        >>> dataset = CmuArcticDataset(Path("datasets/cmu_arctic"), speaker="bdl", download=True)
+        >>> audio, fs = dataset.load_wav("arctic_a0001")
+    """
+
+    def __init__(
+        self, root: Path, speaker: str = "bdl", download: bool = False
+    ) -> None:
         """Initialize a CMU ARCTIC dataset handle.

         Args:
@@ -182,7 +192,11 @@ def _parse_text_line(line: str) -> Tuple[str, str]:


 def load_wav_mono(path: Path) -> Tuple[torch.Tensor, int]:
-    """Load a wav file and return mono audio and sample rate."""
+    """Load a wav file and return mono audio and sample rate.
+
+    Example:
+        >>> audio, fs = load_wav_mono(Path("datasets/cmu_arctic/ARCTIC/.../wav/arctic_a0001.wav"))
+    """
     import soundfile as sf

     audio, sample_rate = sf.read(str(path), dtype="float32", always_2d=True)
@@ -195,7 +209,11 @@ def load_wav_mono(path: Path) -> Tuple[torch.Tensor, int]:


 def save_wav(path: Path, audio: torch.Tensor, sample_rate: int) -> None:
-    """Save a mono or multi-channel wav to disk."""
+    """Save a mono or multi-channel wav to disk.
+
+    Example:
+        >>> save_wav(Path("outputs/example.wav"), audio, sample_rate)
+    """
     import soundfile as sf

     audio = audio.detach().cpu().clamp(-1.0, 1.0).to(torch.float32)
@@ -34,7 +34,9 @@ class TemplateDataset(BaseDataset):
     protocol intact.
     """

-    def __init__(self, root: Path, speaker: str = "default", download: bool = False) -> None:
+    def __init__(
+        self, root: Path, speaker: str = "default", download: bool = False
+    ) -> None:
         self.root = Path(root)
         self.speaker = speaker
         if download:
@@ -10,8 +10,15 @@ import torch
 from .base import BaseDataset, SentenceLike


-def choose_speakers(dataset: BaseDataset, num_sources: int, rng: random.Random) -> List[str]:
-    """Select unique speakers for the requested number of sources."""
+def choose_speakers(
+    dataset: BaseDataset, num_sources: int, rng: random.Random
+) -> List[str]:
+    """Select unique speakers for the requested number of sources.
+
+    Example:
+        >>> rng = random.Random(0)
+        >>> speakers = choose_speakers(dataset, num_sources=2, rng=rng)
+    """
     speakers = dataset.list_speakers()
     if not speakers:
         raise RuntimeError("no speakers available")
@@ -27,7 +34,20 @@ def load_dataset_sources(
     duration_s: float,
     rng: random.Random,
 ) -> Tuple[torch.Tensor, int, List[Tuple[str, List[str]]]]:
-    """Load and concatenate utterances for each speaker into fixed-length signals."""
+    """Load and concatenate utterances for each speaker into fixed-length signals.
+
+    Example:
+        >>> from pathlib import Path
+        >>> from torchrir import CmuArcticDataset
+        >>> rng = random.Random(0)
+        >>> root = Path("datasets/cmu_arctic")
+        >>> signals, fs, info = load_dataset_sources(
+        ...     dataset_factory=lambda spk: CmuArcticDataset(root, speaker=spk, download=True),
+        ...     num_sources=2,
+        ...     duration_s=10.0,
+        ...     rng=rng,
+        ... )
+    """
     dataset0 = dataset_factory(None)
     speakers = choose_speakers(dataset0, num_sources, rng)
     signals: List[torch.Tensor] = []
@@ -71,4 +91,6 @@
         info.append((speaker, utterance_ids))

     stacked = torch.stack(signals, dim=0)
+    if fs is None:
+        raise RuntimeError("no audio loaded from dataset sources")
     return stacked, int(fs), info
torchrir/dynamic.py CHANGED
@@ -17,7 +17,12 @@ from .signal import _ensure_dynamic_rirs, _ensure_signal

 @dataclass(frozen=True)
 class DynamicConvolver:
-    """Convolver for time-varying RIRs."""
+    """Convolver for time-varying RIRs.
+
+    Example:
+        >>> convolver = DynamicConvolver(mode="trajectory")
+        >>> y = convolver.convolve(signal, rirs)
+    """

     mode: str = "trajectory"
     hop: Optional[int] = None
@@ -28,14 +33,20 @@
         return self.convolve(signal, rirs)

     def convolve(self, signal: Tensor, rirs: Tensor) -> Tensor:
-        """Convolve signals with time-varying RIRs."""
+        """Convolve signals with time-varying RIRs.
+
+        Example:
+            >>> y = DynamicConvolver(mode="hop", hop=1024).convolve(signal, rirs)
+        """
         if self.mode not in ("trajectory", "hop"):
             raise ValueError("mode must be 'trajectory' or 'hop'")
         if self.mode == "hop":
             if self.hop is None:
                 raise ValueError("hop must be provided for hop mode")
             return _convolve_dynamic_hop(signal, rirs, self.hop)
-        return _convolve_dynamic_trajectory(signal, rirs, timestamps=self.timestamps, fs=self.fs)
+        return _convolve_dynamic_trajectory(
+            signal, rirs, timestamps=self.timestamps, fs=self.fs
+        )


 def _convolve_dynamic_hop(signal: Tensor, rirs: Tensor, hop: int) -> Tensor:
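
A sketch contrasting the two convolver modes with random tensors; the `(n_src, n_samples)` signal layout is an assumption inferred from the `(T, n_src, n_mic, nsample)` RIR shape that `simulate_dynamic_rir` documents (the `_ensure_*` validators are not shown in this diff):

    import torch

    from torchrir import DynamicConvolver

    T, n_src, n_mic, nsample = 8, 1, 2, 512
    signal = torch.randn(n_src, 16000)            # assumed (n_src, n_samples)
    rirs = torch.randn(T, n_src, n_mic, nsample)  # as returned by simulate_dynamic_rir

    # "trajectory" spreads the T frames over the whole signal (optionally
    # guided by timestamps/fs); "hop" advances one frame every `hop` samples.
    y_traj = DynamicConvolver(mode="trajectory").convolve(signal, rirs)
    y_hop = DynamicConvolver(mode="hop", hop=1024).convolve(signal, rirs)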
torchrir/logging_utils.py CHANGED
@@ -9,7 +9,12 @@ from typing import Optional

 @dataclass(frozen=True)
 class LoggingConfig:
-    """Configuration for torchrir logging."""
+    """Configuration for torchrir logging.
+
+    Example:
+        >>> config = LoggingConfig(level="INFO")
+        >>> logger = setup_logging(config)
+    """

     level: str | int = "INFO"
     format: str = "%(levelname)s:%(name)s:%(message)s"
@@ -33,7 +38,12 @@


 def setup_logging(config: LoggingConfig, *, name: str = "torchrir") -> logging.Logger:
-    """Configure and return the base torchrir logger."""
+    """Configure and return the base torchrir logger.
+
+    Example:
+        >>> logger = setup_logging(LoggingConfig(level="DEBUG"))
+        >>> logger.info("ready")
+    """
     logger = logging.getLogger(name)
     level = config.resolve_level()
     logger.setLevel(level)
@@ -47,7 +57,11 @@ def setup_logging(config: LoggingConfig, *, name: str = "torchrir") -> logging.L


 def get_logger(name: Optional[str] = None) -> logging.Logger:
-    """Return a torchrir logger, namespaced under the torchrir root."""
+    """Return a torchrir logger, namespaced under the torchrir root.
+
+    Example:
+        >>> logger = get_logger("examples.static")
+    """
     if not name:
         return logging.getLogger("torchrir")
     if name.startswith("torchrir"):