torchrir 0.1.2__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
torchrir/__init__.py CHANGED
@@ -18,6 +18,11 @@ from .datasets import (
     CmuArcticDataset,
     CmuArcticSentence,
     choose_speakers,
+    CollateBatch,
+    collate_dataset_items,
+    DatasetItem,
+    LibriSpeechDataset,
+    LibriSpeechSentence,
     list_cmu_arctic_speakers,
     SentenceLike,
     load_dataset_sources,
@@ -26,7 +31,12 @@ from .datasets import (
     load_wav_mono,
     save_wav,
 )
-from .scene_utils import binaural_mic_positions, clamp_positions, linear_trajectory, sample_positions
+from .scene_utils import (
+    binaural_mic_positions,
+    clamp_positions,
+    linear_trajectory,
+    sample_positions,
+)
 from .utils import (
     att2t_SabineEstimation,
     att2t_sabine_estimation,
@@ -56,6 +66,11 @@ __all__ = [
     "CmuArcticDataset",
     "CmuArcticSentence",
     "choose_speakers",
+    "CollateBatch",
+    "collate_dataset_items",
+    "DatasetItem",
+    "LibriSpeechDataset",
+    "LibriSpeechSentence",
     "DynamicConvolver",
     "estimate_beta_from_t60",
     "estimate_t60_from_beta",
torchrir/animation.py CHANGED
@@ -104,15 +104,15 @@ def animate_scene_gif(
     mic_lines = []
     for _ in range(view_src_traj.shape[1]):
         if view_dim == 2:
-            line, = ax.plot([], [], color="tab:green", alpha=0.6)
+            (line,) = ax.plot([], [], color="tab:green", alpha=0.6)
         else:
-            line, = ax.plot([], [], [], color="tab:green", alpha=0.6)
+            (line,) = ax.plot([], [], [], color="tab:green", alpha=0.6)
         src_lines.append(line)
     for _ in range(view_mic_traj.shape[1]):
         if view_dim == 2:
-            line, = ax.plot([], [], color="tab:orange", alpha=0.6)
+            (line,) = ax.plot([], [], color="tab:orange", alpha=0.6)
         else:
-            line, = ax.plot([], [], [], color="tab:orange", alpha=0.6)
+            (line,) = ax.plot([], [], [], color="tab:orange", alpha=0.6)
         mic_lines.append(line)
 
     ax.legend(loc="best")
@@ -137,15 +137,15 @@ def animate_scene_gif(
                 xy = mic_frame[:, m_idx, :]
                 line.set_data(xy[:, 0], xy[:, 1])
         else:
-            src_scatter._offsets3d = (
-                src_pos_frame[:, 0],
-                src_pos_frame[:, 1],
-                src_pos_frame[:, 2],
+            setattr(
+                src_scatter,
+                "_offsets3d",
+                (src_pos_frame[:, 0], src_pos_frame[:, 1], src_pos_frame[:, 2]),
             )
-            mic_scatter._offsets3d = (
-                mic_pos_frame[:, 0],
-                mic_pos_frame[:, 1],
-                mic_pos_frame[:, 2],
+            setattr(
+                mic_scatter,
+                "_offsets3d",
+                (mic_pos_frame[:, 0], mic_pos_frame[:, 1], mic_pos_frame[:, 2]),
             )
             for s_idx, line in enumerate(src_lines):
                 xyz = src_frame[:, s_idx, :]
@@ -166,7 +166,10 @@ def animate_scene_gif(
         fps = frames / duration_s
     else:
         fps = 6.0
-    anim = animation.FuncAnimation(fig, _frame, frames=frames, interval=1000 / fps, blit=False)
-    anim.save(out_path, writer="pillow", fps=fps)
+    anim = animation.FuncAnimation(
+        fig, _frame, frames=frames, interval=1000 / fps, blit=False
+    )
+    fps_int = None if fps is None else max(1, int(round(fps)))
+    anim.save(out_path, writer="pillow", fps=fps_int)
     plt.close(fig)
     return out_path
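
A plausible reading of the fps change above: GIF frame durations are quantized and matplotlib's pillow writer works best with a small positive integer frame rate, so the fractional frames/duration value is clamped and rounded before saving. Illustrative arithmetic only:

    frames, duration_s = 25, 6.0
    fps = frames / duration_s             # 4.1666...
    fps_int = max(1, int(round(fps)))     # 4 is what reaches anim.save(..., fps=...)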
torchrir/core.py CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
 """Core RIR simulation functions (static and dynamic)."""
 
 import math
+from collections.abc import Callable
 from typing import Optional, Tuple
 
 import torch
@@ -61,8 +62,8 @@ def simulate_rir(
 
     Example:
         >>> room = Room.shoebox(size=[6.0, 4.0, 3.0], fs=16000, beta=[0.9] * 6)
-        >>> sources = Source.positions([[1.0, 2.0, 1.5]])
-        >>> mics = MicrophoneArray.positions([[2.0, 2.0, 1.5]])
+        >>> sources = Source.from_positions([[1.0, 2.0, 1.5]])
+        >>> mics = MicrophoneArray.from_positions([[2.0, 2.0, 1.5]])
         >>> rir = simulate_rir(
         ...     room=room,
         ...     sources=sources,
@@ -90,9 +91,9 @@ def simulate_rir(
 
     if not isinstance(room, Room):
        raise TypeError("room must be a Room instance")
-    if nsample is None and tmax is None:
-        raise ValueError("nsample or tmax must be provided")
     if nsample is None:
+        if tmax is None:
+            raise ValueError("nsample or tmax must be provided")
        nsample = int(math.ceil(tmax * room.fs))
     if nsample <= 0:
        raise ValueError("nsample must be positive")
@@ -261,6 +262,11 @@ def simulate_dynamic_rir(
 
     src_traj = as_tensor(src_traj, device=device, dtype=dtype)
     mic_traj = as_tensor(mic_traj, device=device, dtype=dtype)
+    device, dtype = infer_device_dtype(
+        src_traj, mic_traj, room.size, device=device, dtype=dtype
+    )
+    src_traj = as_tensor(src_traj, device=device, dtype=dtype)
+    mic_traj = as_tensor(mic_traj, device=device, dtype=dtype)
 
     if src_traj.ndim == 2:
         src_traj = src_traj.unsqueeze(1)
@@ -273,24 +279,95 @@ def simulate_dynamic_rir(
     if src_traj.shape[0] != mic_traj.shape[0]:
         raise ValueError("src_traj and mic_traj must have the same time length")
 
-    t_steps = src_traj.shape[0]
-    rirs = []
-    for t_idx in range(t_steps):
-        rir = simulate_rir(
-            room=room,
-            sources=src_traj[t_idx],
-            mics=mic_traj[t_idx],
-            max_order=max_order,
-            nsample=nsample,
-            tmax=tmax,
-            directivity=directivity,
-            orientation=orientation,
-            config=config,
-            device=device,
-            dtype=dtype,
+    if not isinstance(room, Room):
+        raise TypeError("room must be a Room instance")
+    if nsample is None:
+        if tmax is None:
+            raise ValueError("nsample or tmax must be provided")
+        nsample = int(math.ceil(tmax * room.fs))
+    if nsample <= 0:
+        raise ValueError("nsample must be positive")
+    if max_order < 0:
+        raise ValueError("max_order must be non-negative")
+
+    room_size = as_tensor(room.size, device=device, dtype=dtype)
+    room_size = ensure_dim(room_size)
+    dim = room_size.numel()
+    if src_traj.shape[2] != dim:
+        raise ValueError("src_traj must match room dimension")
+    if mic_traj.shape[2] != dim:
+        raise ValueError("mic_traj must match room dimension")
+
+    src_ori = None
+    mic_ori = None
+    if orientation is not None:
+        if isinstance(orientation, (list, tuple)):
+            if len(orientation) != 2:
+                raise ValueError("orientation tuple must have length 2")
+            src_ori, mic_ori = orientation
+        else:
+            src_ori = orientation
+            mic_ori = orientation
+    if src_ori is not None:
+        src_ori = as_tensor(src_ori, device=device, dtype=dtype)
+    if mic_ori is not None:
+        mic_ori = as_tensor(mic_ori, device=device, dtype=dtype)
+
+    beta = _resolve_beta(room, room_size, device=device, dtype=dtype)
+    beta = _validate_beta(beta, dim)
+    n_vec = _image_source_indices(max_order, dim, device=device, nb_img=None)
+    refl = _reflection_coefficients(n_vec, beta)
+
+    src_pattern, mic_pattern = split_directivity(directivity)
+    mic_dir = None
+    if mic_pattern != "omni":
+        if mic_ori is None:
+            raise ValueError("mic orientation required for non-omni directivity")
+        mic_dir = orientation_to_unit(mic_ori, dim)
+
+    n_src = src_traj.shape[1]
+    n_mic = mic_traj.shape[1]
+    rirs = torch.zeros((src_traj.shape[0], n_src, n_mic, nsample), device=device, dtype=dtype)
+    fdl = cfg.frac_delay_length
+    fdl2 = (fdl - 1) // 2
+    img_chunk = cfg.image_chunk_size
+    if img_chunk <= 0:
+        img_chunk = n_vec.shape[0]
+
+    src_dirs = None
+    if src_pattern != "omni":
+        if src_ori is None:
+            raise ValueError("source orientation required for non-omni directivity")
+        src_dirs = orientation_to_unit(src_ori, dim)
+        if src_dirs.ndim == 1:
+            src_dirs = src_dirs.unsqueeze(0).repeat(n_src, 1)
+        if src_dirs.ndim != 2 or src_dirs.shape[0] != n_src:
+            raise ValueError("source orientation must match number of sources")
+
+    for start in range(0, n_vec.shape[0], img_chunk):
+        end = min(start + img_chunk, n_vec.shape[0])
+        n_vec_chunk = n_vec[start:end]
+        refl_chunk = refl[start:end]
+        sample_chunk, attenuation_chunk = _compute_image_contributions_time_batch(
+            src_traj,
+            mic_traj,
+            room_size,
+            n_vec_chunk,
+            refl_chunk,
+            room,
+            fdl2,
+            src_pattern=src_pattern,
+            mic_pattern=mic_pattern,
+            src_dirs=src_dirs,
+            mic_dir=mic_dir,
         )
-        rirs.append(rir)
-    return torch.stack(rirs, dim=0)
+        t_steps = src_traj.shape[0]
+        sample_flat = sample_chunk.reshape(t_steps * n_src, n_mic, -1)
+        attenuation_flat = attenuation_chunk.reshape(t_steps * n_src, n_mic, -1)
+        rir_flat = rirs.view(t_steps * n_src, n_mic, nsample)
+        _accumulate_rir_batch(rir_flat, sample_flat, attenuation_flat, cfg)
+
+    return rirs
 
 
 def _prepare_entities(
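
The rewrite above replaces the per-time-step loop over simulate_rir with a single batched path: image sources are processed in chunks, contributions for every time step are computed at once, and the result keeps one RIR per time step, source and microphone. A usage sketch under stated assumptions (import paths follow the module shown in this diff, the keyword names are the ones visible above, and the trajectory values are illustrative):

    import torch
    from torchrir.core import Room, simulate_dynamic_rir

    room = Room.shoebox(size=[6.0, 4.0, 3.0], fs=16000, beta=[0.9] * 6)
    T = 8
    # One source moving along x, one static microphone: shapes (T, n_src, 3) and (T, n_mic, 3).
    src_traj = torch.tensor([[[1.0 + 0.1 * t, 2.0, 1.5]] for t in range(T)])
    mic_traj = torch.tensor([[[2.0, 2.0, 1.5]]]).repeat(T, 1, 1)
    rirs = simulate_dynamic_rir(
        room=room, src_traj=src_traj, mic_traj=mic_traj, max_order=2, nsample=2048
    )
    # Per the allocation above, rirs has shape (T, n_src, n_mic, nsample) == (8, 1, 1, 2048).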
@@ -495,7 +572,11 @@ def _compute_image_contributions_batch(
     if mic_pattern != "omni":
         if mic_dir is None:
             raise ValueError("mic orientation required for non-omni directivity")
-        mic_dir = mic_dir[None, :, None, :] if mic_dir.ndim == 2 else mic_dir.view(1, 1, 1, -1)
+        mic_dir = (
+            mic_dir[None, :, None, :]
+            if mic_dir.ndim == 2
+            else mic_dir.view(1, 1, 1, -1)
+        )
         cos_theta = _cos_between(-vec, mic_dir)
         gain = gain * directivity_gain(mic_pattern, cos_theta)
 
@@ -503,6 +584,54 @@ def _compute_image_contributions_batch(
     return sample, attenuation
 
 
+def _compute_image_contributions_time_batch(
+    src_traj: Tensor,
+    mic_traj: Tensor,
+    room_size: Tensor,
+    n_vec: Tensor,
+    refl: Tensor,
+    room: Room,
+    fdl2: int,
+    *,
+    src_pattern: str,
+    mic_pattern: str,
+    src_dirs: Optional[Tensor],
+    mic_dir: Optional[Tensor],
+) -> Tuple[Tensor, Tensor]:
+    """Compute samples/attenuation for all time steps in batch."""
+    sign = torch.where((n_vec % 2) == 0, 1.0, -1.0).to(dtype=src_traj.dtype)
+    n = torch.floor_divide(n_vec + 1, 2).to(dtype=src_traj.dtype)
+    base = 2.0 * room_size * n
+    img = base[None, None, :, :] + sign[None, None, :, :] * src_traj[:, :, None, :]
+    vec = mic_traj[:, None, :, None, :] - img[:, :, None, :, :]
+    dist = torch.linalg.norm(vec, dim=-1)
+    dist = torch.clamp(dist, min=1e-6)
+    time = dist / room.c
+    time = time + (fdl2 / room.fs)
+    sample = time * room.fs
+
+    gain = refl.view(1, 1, 1, -1)
+    if src_pattern != "omni":
+        if src_dirs is None:
+            raise ValueError("source orientation required for non-omni directivity")
+        src_dirs_b = src_dirs[None, :, None, None, :]
+        cos_theta = _cos_between(vec, src_dirs_b)
+        gain = gain * directivity_gain(src_pattern, cos_theta)
+    if mic_pattern != "omni":
+        if mic_dir is None:
+            raise ValueError("mic orientation required for non-omni directivity")
+        mic_dir_b = (
+            mic_dir[None, None, :, None, :]
+            if mic_dir.ndim == 2
+            else mic_dir.view(1, 1, 1, 1, -1)
+        )
+        cos_theta = _cos_between(-vec, mic_dir_b)
+        gain = gain * directivity_gain(mic_pattern, cos_theta)
+
+    attenuation = gain / dist
+    return sample, attenuation
+
+
 def _select_orientation(orientation: Tensor, idx: int, count: int, dim: int) -> Tensor:
     """Pick the correct orientation vector for a given entity index."""
     if orientation.ndim == 0:
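
A numeric sketch (one axis, illustrative values, assuming c ≈ 343 m/s) of the image-source geometry the new batched helper evaluates: img = 2 * room_size * n + sign * src with sign = +1 for even image indices and -1 for odd ones, then the delay is the image-to-mic distance expressed in samples:

    L_x, src_x, mic_x = 6.0, 1.0, 2.0
    n_vec = 1                                 # first-order image across the far wall
    sign = 1.0 if n_vec % 2 == 0 else -1.0    # -1 for odd indices
    n = (n_vec + 1) // 2                      # 1
    img_x = 2.0 * L_x * n + sign * src_x      # 11.0, i.e. 2*L - src
    dist = abs(mic_x - img_x)                 # 9.0 m
    delay_samples = dist / 343.0 * 16000      # ≈ 419.8, before the fractional-delay offset fdl2/fs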
@@ -542,9 +671,9 @@ def _accumulate_rir(
     if use_lut:
         sinc_lut = _get_sinc_lut(fdl, lut_gran, device=rir.device, dtype=dtype)
 
-    mic_offsets = (torch.arange(n_mic, device=rir.device, dtype=torch.int64) * nsample).view(
-        n_mic, 1, 1
-    )
+    mic_offsets = (
+        torch.arange(n_mic, device=rir.device, dtype=torch.int64) * nsample
+    ).view(n_mic, 1, 1)
     rir_flat = rir.view(-1)
 
     chunk_size = cfg.accumulate_chunk_size
@@ -559,7 +688,9 @@
         x_off_frac = (1.0 - frac_m) * lut_gran
         lut_gran_off = torch.floor(x_off_frac).to(torch.int64)
         x_off = x_off_frac - lut_gran_off.to(dtype)
-        lut_pos = lut_gran_off[..., None] + (n[None, None, :].to(torch.int64) * lut_gran)
+        lut_pos = lut_gran_off[..., None] + (
+            n[None, None, :].to(torch.int64) * lut_gran
+        )
 
         s0 = torch.take(sinc_lut, lut_pos)
         s1 = torch.take(sinc_lut, lut_pos + 1)
@@ -618,9 +749,9 @@ def _accumulate_rir_batch_impl(
     if use_lut:
         sinc_lut = _get_sinc_lut(fdl, lut_gran, device=rir.device, dtype=sample.dtype)
 
-    sm_offsets = (torch.arange(n_sm, device=rir.device, dtype=torch.int64) * nsample).view(
-        n_sm, 1, 1
-    )
+    sm_offsets = (
+        torch.arange(n_sm, device=rir.device, dtype=torch.int64) * nsample
+    ).view(n_sm, 1, 1)
     rir_flat = rir.view(-1)
 
     n_img = idx0.shape[1]
@@ -634,7 +765,9 @@
         x_off_frac = (1.0 - frac_m) * lut_gran
         lut_gran_off = torch.floor(x_off_frac).to(torch.int64)
         x_off = x_off_frac - lut_gran_off.to(sample.dtype)
-        lut_pos = lut_gran_off[..., None] + (n[None, None, :].to(torch.int64) * lut_gran)
+        lut_pos = lut_gran_off[..., None] + (
+            n[None, None, :].to(torch.int64) * lut_gran
+        )
 
         s0 = torch.take(sinc_lut, lut_pos)
         s1 = torch.take(sinc_lut, lut_pos + 1)
@@ -660,12 +793,13 @@ _SINC_LUT_CACHE: dict[tuple[int, int, str, torch.dtype], Tensor] = {}
 _FDL_GRID_CACHE: dict[tuple[int, str, torch.dtype], Tensor] = {}
 _FDL_OFFSETS_CACHE: dict[tuple[int, str], Tensor] = {}
 _FDL_WINDOW_CACHE: dict[tuple[int, str, torch.dtype], Tensor] = {}
-_ACCUM_BATCH_COMPILED: dict[tuple[str, torch.dtype, int, int, bool, int], callable] = {}
+_AccumFn = Callable[[Tensor, Tensor, Tensor], None]
+_ACCUM_BATCH_COMPILED: dict[tuple[str, torch.dtype, int, int, bool, int], _AccumFn] = {}
 
 
 def _get_accumulate_fn(
     cfg: SimulationConfig, device: torch.device, dtype: torch.dtype
-) -> callable:
+) -> _AccumFn:
     """Return an accumulation function with config-bound constants."""
     use_lut = cfg.use_lut and device.type != "mps"
     fdl = cfg.frac_delay_length
@@ -721,7 +855,9 @@ def _get_fdl_window(fdl: int, *, device: torch.device, dtype: torch.dtype) -> Te
     return cached
 
 
-def _get_sinc_lut(fdl: int, lut_gran: int, *, device: torch.device, dtype: torch.dtype) -> Tensor:
+def _get_sinc_lut(
+    fdl: int, lut_gran: int, *, device: torch.device, dtype: torch.dtype
+) -> Tensor:
     """Create a sinc lookup table for fractional delays."""
     key = (fdl, lut_gran, str(device), dtype)
     cached = _SINC_LUT_CACHE.get(key)
@@ -765,7 +901,12 @@ def _apply_diffuse_tail(
 
     gen = torch.Generator(device=rir.device)
     gen.manual_seed(0 if seed is None else seed)
-    noise = torch.randn(rir[..., tdiff_idx:].shape, device=rir.device, dtype=rir.dtype, generator=gen)
-    scale = torch.linalg.norm(rir[..., tdiff_idx - 1 : tdiff_idx], dim=-1, keepdim=True) + 1e-8
+    noise = torch.randn(
+        rir[..., tdiff_idx:].shape, device=rir.device, dtype=rir.dtype, generator=gen
+    )
+    scale = (
+        torch.linalg.norm(rir[..., tdiff_idx - 1 : tdiff_idx], dim=-1, keepdim=True)
+        + 1e-8
+    )
     rir[..., tdiff_idx:] = noise * decay * scale
     return rir
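
The diffuse tail above draws its noise from a seeded torch.Generator, so the tail is reproducible for a given seed. An illustrative property check, not the package's exact call path:

    import torch

    gen_a = torch.Generator().manual_seed(0)
    gen_b = torch.Generator().manual_seed(0)
    assert torch.equal(torch.randn(4, generator=gen_a), torch.randn(4, generator=gen_b))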
torchrir/datasets/__init__.py CHANGED
@@ -1,14 +1,15 @@
 """Dataset helpers for torchrir."""
 
-from .base import BaseDataset, SentenceLike
-from .utils import choose_speakers, load_dataset_sources
+from .base import BaseDataset, DatasetItem, SentenceLike
+from .utils import choose_speakers, load_dataset_sources, load_wav_mono
+from .collate import CollateBatch, collate_dataset_items
 from .template import TemplateDataset, TemplateSentence
+from .librispeech import LibriSpeechDataset, LibriSpeechSentence
 
 from .cmu_arctic import (
     CmuArcticDataset,
     CmuArcticSentence,
     list_cmu_arctic_speakers,
-    load_wav_mono,
     save_wav,
 )
 
@@ -17,6 +18,9 @@ __all__ = [
     "CmuArcticDataset",
     "CmuArcticSentence",
     "choose_speakers",
+    "DatasetItem",
+    "CollateBatch",
+    "collate_dataset_items",
     "list_cmu_arctic_speakers",
     "SentenceLike",
     "load_dataset_sources",
@@ -24,4 +28,6 @@ __all__ = [
     "save_wav",
     "TemplateDataset",
     "TemplateSentence",
+    "LibriSpeechDataset",
+    "LibriSpeechSentence",
 ]
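
One compatibility note visible in this hunk: load_wav_mono is no longer imported from the cmu_arctic module but from torchrir.datasets.utils, while the package-level re-export is preserved. Existing call sites importing it from torchrir or torchrir.datasets keep working; only the deep import from the old location breaks. A sketch of the import paths per this diff:

    from torchrir.datasets import load_wav_mono          # still works in 0.2.0
    from torchrir.datasets.utils import load_wav_mono    # new home per this diff
    # from torchrir.datasets.cmu_arctic import load_wav_mono  # removed in 0.2.0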
torchrir/datasets/base.py CHANGED
@@ -2,9 +2,11 @@ from __future__ import annotations
 
 """Dataset protocol definitions."""
 
-from typing import Protocol, Sequence, Tuple
+from dataclasses import dataclass
+from typing import Optional, Protocol, Sequence, Tuple
 
 import torch
+from torch.utils.data import Dataset
 
 
 class SentenceLike(Protocol):
@@ -14,14 +16,52 @@ class SentenceLike(Protocol):
     text: str
 
 
-class BaseDataset(Protocol):
-    """Protocol for datasets used in torchrir examples and tools."""
+@dataclass(frozen=True)
+class DatasetItem:
+    """Dataset item for DataLoader consumption."""
+
+    audio: torch.Tensor
+    sample_rate: int
+    utterance_id: str
+    text: Optional[str] = None
+    speaker: Optional[str] = None
+
+
+class BaseDataset(Dataset[DatasetItem]):
+    """Base dataset class compatible with torch.utils.data.Dataset."""
+
+    _sentences_cache: Optional[list[SentenceLike]] = None
 
     def list_speakers(self) -> list[str]:
         """Return available speaker IDs."""
+        raise NotImplementedError
 
     def available_sentences(self) -> Sequence[SentenceLike]:
         """Return sentence entries that have audio available."""
+        raise NotImplementedError
 
     def load_wav(self, utterance_id: str) -> Tuple[torch.Tensor, int]:
         """Load audio for an utterance and return (audio, sample_rate)."""
+        raise NotImplementedError
+
+    def __len__(self) -> int:
+        return len(self._get_sentences())
+
+    def __getitem__(self, idx: int) -> DatasetItem:
+        sentences = self._get_sentences()
+        sentence = sentences[idx]
+        audio, sample_rate = self.load_wav(sentence.utterance_id)
+        speaker = getattr(self, "speaker", None)
+        text = getattr(sentence, "text", None)
+        return DatasetItem(
+            audio=audio,
+            sample_rate=sample_rate,
+            utterance_id=sentence.utterance_id,
+            text=text,
+            speaker=speaker,
+        )
+
+    def _get_sentences(self) -> list[SentenceLike]:
+        if self._sentences_cache is None:
+            self._sentences_cache = list(self.available_sentences())
+        return self._sentences_cache
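
With this change, a concrete dataset only has to implement the three methods that used to make up the protocol; length, indexing and DatasetItem construction come from BaseDataset, so any subclass is immediately usable as a map-style torch Dataset. A minimal sketch (the ToySentence/ToyDataset names are illustrative, not part of the package):

    from dataclasses import dataclass
    from typing import Sequence, Tuple

    import torch

    from torchrir.datasets import BaseDataset, DatasetItem, SentenceLike

    @dataclass
    class ToySentence:
        utterance_id: str
        text: str

    class ToyDataset(BaseDataset):
        speaker = "toy"  # picked up by __getitem__ via getattr(self, "speaker", None)

        def list_speakers(self) -> list[str]:
            return [self.speaker]

        def available_sentences(self) -> Sequence[SentenceLike]:
            return [ToySentence("utt0", "hello"), ToySentence("utt1", "world")]

        def load_wav(self, utterance_id: str) -> Tuple[torch.Tensor, int]:
            return torch.zeros(16000), 16000

    item: DatasetItem = ToyDataset()[0]
    print(len(ToyDataset()), item.utterance_id, item.sample_rate)  # 2 utt0 16000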
torchrir/datasets/cmu_arctic.py CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 """CMU ARCTIC dataset helpers."""
 
+import logging
 import tarfile
 import urllib.request
 from dataclasses import dataclass
@@ -9,7 +10,9 @@ from pathlib import Path
 from typing import List, Tuple
 
 import torch
-import logging
+
+from .base import BaseDataset
+from .utils import load_wav_mono
 
 BASE_URL = "http://www.festvox.org/cmu_arctic/packed"
 VALID_SPEAKERS = {
@@ -44,11 +47,12 @@ def list_cmu_arctic_speakers() -> List[str]:
 @dataclass
 class CmuArcticSentence:
     """Sentence metadata from CMU ARCTIC."""
+
     utterance_id: str
     text: str
 
 
-class CmuArcticDataset:
+class CmuArcticDataset(BaseDataset):
     """CMU ARCTIC dataset loader.
 
     Example:
@@ -56,7 +60,9 @@ class CmuArcticDataset:
         >>> audio, fs = dataset.load_wav("arctic_a0001")
     """
 
-    def __init__(self, root: Path, speaker: str = "bdl", download: bool = False) -> None:
+    def __init__(
+        self, root: Path, speaker: str = "bdl", download: bool = False
+    ) -> None:
         """Initialize a CMU ARCTIC dataset handle.
 
         Args:
@@ -188,23 +194,6 @@ def _parse_text_line(line: str) -> Tuple[str, str]:
     return utterance, text
 
 
-def load_wav_mono(path: Path) -> Tuple[torch.Tensor, int]:
-    """Load a wav file and return mono audio and sample rate.
-
-    Example:
-        >>> audio, fs = load_wav_mono(Path("datasets/cmu_arctic/ARCTIC/.../wav/arctic_a0001.wav"))
-    """
-    import soundfile as sf
-
-    audio, sample_rate = sf.read(str(path), dtype="float32", always_2d=True)
-    audio_t = torch.from_numpy(audio)
-    if audio_t.shape[1] > 1:
-        audio_t = audio_t.mean(dim=1)
-    else:
-        audio_t = audio_t.squeeze(1)
-    return audio_t, sample_rate
-
-
 def save_wav(path: Path, audio: torch.Tensor, sample_rate: int) -> None:
     """Save a mono or multi-channel wav to disk.
 
torchrir/datasets/collate.py ADDED
@@ -0,0 +1,90 @@
+from __future__ import annotations
+
+"""Collate helpers for DataLoader usage."""
+
+from dataclasses import dataclass
+from typing import Any, Iterable, List, Optional
+
+import torch
+from torch import Tensor
+
+from .base import DatasetItem
+
+
+@dataclass(frozen=True)
+class CollateBatch:
+    """Collated batch of dataset items.
+
+    Attributes:
+        audio: Padded audio tensor of shape (batch, max_len).
+        lengths: Original lengths for each item.
+        sample_rate: Sample rate shared across the batch.
+        utterance_ids: Utterance IDs per item.
+        texts: Optional text per item.
+        speakers: Optional speaker IDs per item.
+        metadata: Optional per-item metadata (pass-through).
+    """
+
+    audio: Tensor
+    lengths: Tensor
+    sample_rate: int
+    utterance_ids: list[str]
+    texts: list[Optional[str]]
+    speakers: list[Optional[str]]
+    metadata: Optional[list[Any]] = None
+
+
+def collate_dataset_items(
+    items: Iterable[DatasetItem],
+    *,
+    pad_value: float = 0.0,
+    keep_metadata: bool = False,
+) -> CollateBatch:
+    """Collate DatasetItem entries into a padded batch.
+
+    Args:
+        items: Iterable of DatasetItem.
+        pad_value: Value used for padding.
+        keep_metadata: Preserve item-level metadata field if present.
+
+    Returns:
+        CollateBatch with padded audio and metadata lists.
+    """
+    batch = list(items)
+    if not batch:
+        raise ValueError("collate_dataset_items received an empty batch")
+
+    sample_rate = batch[0].sample_rate
+    for item in batch[1:]:
+        if item.sample_rate != sample_rate:
+            raise ValueError("sample_rate must be consistent within a batch")
+
+    lengths = torch.tensor([item.audio.numel() for item in batch], dtype=torch.long)
+    max_len = int(lengths.max().item())
+    audio = torch.full(
+        (len(batch), max_len),
+        pad_value,
+        dtype=batch[0].audio.dtype,
+        device=batch[0].audio.device,
+    )
+
+    for idx, item in enumerate(batch):
+        audio[idx, : item.audio.numel()] = item.audio
+
+    utterance_ids = [item.utterance_id for item in batch]
+    texts = [item.text for item in batch]
+    speakers = [item.speaker for item in batch]
+
+    metadata: Optional[list[Any]] = None
+    if keep_metadata:
+        metadata = [getattr(item, "metadata", None) for item in batch]
+
+    return CollateBatch(
+        audio=audio,
+        lengths=lengths,
+        sample_rate=sample_rate,
+        utterance_ids=utterance_ids,
+        texts=texts,
+        speakers=speakers,
+        metadata=metadata,
+    )
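
A hedged end-to-end usage sketch of the new collate path (the dataset root is illustrative; the CmuArcticDataset constructor arguments and the pad_value keyword are taken from the hunks above):

    from functools import partial
    from pathlib import Path

    import torch
    from torch.utils.data import DataLoader

    from torchrir.datasets import CmuArcticDataset, collate_dataset_items

    dataset = CmuArcticDataset(root=Path("datasets/cmu_arctic"), speaker="bdl", download=True)
    loader = DataLoader(dataset, batch_size=4, collate_fn=partial(collate_dataset_items, pad_value=0.0))

    batch = next(iter(loader))
    # batch.audio is (4, max_len) padded with zeros; batch.lengths recovers a validity mask.
    mask = torch.arange(batch.audio.shape[1])[None, :] < batch.lengths[:, None]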