torchrir 0.1.2__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchrir/__init__.py +16 -1
- torchrir/animation.py +17 -14
- torchrir/core.py +176 -35
- torchrir/datasets/__init__.py +9 -3
- torchrir/datasets/base.py +43 -3
- torchrir/datasets/cmu_arctic.py +9 -20
- torchrir/datasets/collate.py +90 -0
- torchrir/datasets/librispeech.py +175 -0
- torchrir/datasets/template.py +3 -1
- torchrir/datasets/utils.py +23 -1
- torchrir/dynamic.py +3 -1
- torchrir/plotting.py +13 -6
- torchrir/plotting_utils.py +4 -1
- torchrir/room.py +2 -38
- torchrir/scene_utils.py +6 -2
- torchrir/signal.py +24 -10
- torchrir/simulators.py +12 -4
- torchrir/utils.py +1 -1
- torchrir-0.2.0.dist-info/METADATA +70 -0
- torchrir-0.2.0.dist-info/RECORD +30 -0
- torchrir-0.1.2.dist-info/METADATA +0 -271
- torchrir-0.1.2.dist-info/RECORD +0 -28
- {torchrir-0.1.2.dist-info → torchrir-0.2.0.dist-info}/WHEEL +0 -0
- {torchrir-0.1.2.dist-info → torchrir-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {torchrir-0.1.2.dist-info → torchrir-0.2.0.dist-info}/licenses/NOTICE +0 -0
- {torchrir-0.1.2.dist-info → torchrir-0.2.0.dist-info}/top_level.txt +0 -0
torchrir/__init__.py
CHANGED
@@ -18,6 +18,11 @@ from .datasets import (
     CmuArcticDataset,
     CmuArcticSentence,
     choose_speakers,
+    CollateBatch,
+    collate_dataset_items,
+    DatasetItem,
+    LibriSpeechDataset,
+    LibriSpeechSentence,
     list_cmu_arctic_speakers,
     SentenceLike,
     load_dataset_sources,
@@ -26,7 +31,12 @@ from .datasets import (
     load_wav_mono,
     save_wav,
 )
-from .scene_utils import
+from .scene_utils import (
+    binaural_mic_positions,
+    clamp_positions,
+    linear_trajectory,
+    sample_positions,
+)
 from .utils import (
     att2t_SabineEstimation,
     att2t_sabine_estimation,
@@ -56,6 +66,11 @@ __all__ = [
     "CmuArcticDataset",
     "CmuArcticSentence",
     "choose_speakers",
+    "CollateBatch",
+    "collate_dataset_items",
+    "DatasetItem",
+    "LibriSpeechDataset",
+    "LibriSpeechSentence",
     "DynamicConvolver",
     "estimate_beta_from_t60",
     "estimate_t60_from_beta",
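Taken with the datasets and scene_utils changes below, these exports make the new helpers importable straight from the package root. A quick import check (assuming no other re-export changes in this release):

    from torchrir import (
        CollateBatch,
        collate_dataset_items,
        DatasetItem,
        LibriSpeechDataset,
        LibriSpeechSentence,
        binaural_mic_positions,
    )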
torchrir/animation.py
CHANGED
@@ -104,15 +104,15 @@ def animate_scene_gif(
     mic_lines = []
     for _ in range(view_src_traj.shape[1]):
         if view_dim == 2:
-            line, = ax.plot([], [], color="tab:green", alpha=0.6)
+            (line,) = ax.plot([], [], color="tab:green", alpha=0.6)
         else:
-            line, = ax.plot([], [], [], color="tab:green", alpha=0.6)
+            (line,) = ax.plot([], [], [], color="tab:green", alpha=0.6)
         src_lines.append(line)
     for _ in range(view_mic_traj.shape[1]):
         if view_dim == 2:
-            line, = ax.plot([], [], color="tab:orange", alpha=0.6)
+            (line,) = ax.plot([], [], color="tab:orange", alpha=0.6)
         else:
-            line, = ax.plot([], [], [], color="tab:orange", alpha=0.6)
+            (line,) = ax.plot([], [], [], color="tab:orange", alpha=0.6)
         mic_lines.append(line)

     ax.legend(loc="best")
@@ -137,15 +137,15 @@ def animate_scene_gif(
                 xy = mic_frame[:, m_idx, :]
                 line.set_data(xy[:, 0], xy[:, 1])
         else:
-
-
-
-                src_pos_frame[:, 2],
+            setattr(
+                src_scatter,
+                "_offsets3d",
+                (src_pos_frame[:, 0], src_pos_frame[:, 1], src_pos_frame[:, 2]),
             )
-
-
-
-                mic_pos_frame[:, 2],
+            setattr(
+                mic_scatter,
+                "_offsets3d",
+                (mic_pos_frame[:, 0], mic_pos_frame[:, 1], mic_pos_frame[:, 2]),
             )
             for s_idx, line in enumerate(src_lines):
                 xyz = src_frame[:, s_idx, :]
@@ -166,7 +166,10 @@ def animate_scene_gif(
         fps = frames / duration_s
     else:
         fps = 6.0
-    anim = animation.FuncAnimation(
-
+    anim = animation.FuncAnimation(
+        fig, _frame, frames=frames, interval=1000 / fps, blit=False
+    )
+    fps_int = None if fps is None else max(1, int(round(fps)))
+    anim.save(out_path, writer="pillow", fps=fps_int)
     plt.close(fig)
     return out_path
torchrir/core.py
CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
 """Core RIR simulation functions (static and dynamic)."""

 import math
+from collections.abc import Callable
 from typing import Optional, Tuple

 import torch
@@ -61,8 +62,8 @@ def simulate_rir(

     Example:
         >>> room = Room.shoebox(size=[6.0, 4.0, 3.0], fs=16000, beta=[0.9] * 6)
-        >>> sources = Source.
-        >>> mics = MicrophoneArray.
+        >>> sources = Source.from_positions([[1.0, 2.0, 1.5]])
+        >>> mics = MicrophoneArray.from_positions([[2.0, 2.0, 1.5]])
         >>> rir = simulate_rir(
         ...     room=room,
         ...     sources=sources,
@@ -90,9 +91,9 @@ def simulate_rir(

     if not isinstance(room, Room):
         raise TypeError("room must be a Room instance")
-    if nsample is None and tmax is None:
-        raise ValueError("nsample or tmax must be provided")
     if nsample is None:
+        if tmax is None:
+            raise ValueError("nsample or tmax must be provided")
         nsample = int(math.ceil(tmax * room.fs))
     if nsample <= 0:
         raise ValueError("nsample must be positive")
@@ -261,6 +262,11 @@ def simulate_dynamic_rir(

     src_traj = as_tensor(src_traj, device=device, dtype=dtype)
     mic_traj = as_tensor(mic_traj, device=device, dtype=dtype)
+    device, dtype = infer_device_dtype(
+        src_traj, mic_traj, room.size, device=device, dtype=dtype
+    )
+    src_traj = as_tensor(src_traj, device=device, dtype=dtype)
+    mic_traj = as_tensor(mic_traj, device=device, dtype=dtype)

     if src_traj.ndim == 2:
         src_traj = src_traj.unsqueeze(1)
@@ -273,24 +279,95 @@ def simulate_dynamic_rir(
     if src_traj.shape[0] != mic_traj.shape[0]:
         raise ValueError("src_traj and mic_traj must have the same time length")

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    if not isinstance(room, Room):
+        raise TypeError("room must be a Room instance")
+    if nsample is None:
+        if tmax is None:
+            raise ValueError("nsample or tmax must be provided")
+        nsample = int(math.ceil(tmax * room.fs))
+    if nsample <= 0:
+        raise ValueError("nsample must be positive")
+    if max_order < 0:
+        raise ValueError("max_order must be non-negative")
+
+    room_size = as_tensor(room.size, device=device, dtype=dtype)
+    room_size = ensure_dim(room_size)
+    dim = room_size.numel()
+    if src_traj.shape[2] != dim:
+        raise ValueError("src_traj must match room dimension")
+    if mic_traj.shape[2] != dim:
+        raise ValueError("mic_traj must match room dimension")
+
+    src_ori = None
+    mic_ori = None
+    if orientation is not None:
+        if isinstance(orientation, (list, tuple)):
+            if len(orientation) != 2:
+                raise ValueError("orientation tuple must have length 2")
+            src_ori, mic_ori = orientation
+        else:
+            src_ori = orientation
+            mic_ori = orientation
+    if src_ori is not None:
+        src_ori = as_tensor(src_ori, device=device, dtype=dtype)
+    if mic_ori is not None:
+        mic_ori = as_tensor(mic_ori, device=device, dtype=dtype)
+
+    beta = _resolve_beta(room, room_size, device=device, dtype=dtype)
+    beta = _validate_beta(beta, dim)
+    n_vec = _image_source_indices(max_order, dim, device=device, nb_img=None)
+    refl = _reflection_coefficients(n_vec, beta)
+
+    src_pattern, mic_pattern = split_directivity(directivity)
+    mic_dir = None
+    if mic_pattern != "omni":
+        if mic_ori is None:
+            raise ValueError("mic orientation required for non-omni directivity")
+        mic_dir = orientation_to_unit(mic_ori, dim)
+
+    n_src = src_traj.shape[1]
+    n_mic = mic_traj.shape[1]
+    rirs = torch.zeros((src_traj.shape[0], n_src, n_mic, nsample), device=device, dtype=dtype)
+    fdl = cfg.frac_delay_length
+    fdl2 = (fdl - 1) // 2
+    img_chunk = cfg.image_chunk_size
+    if img_chunk <= 0:
+        img_chunk = n_vec.shape[0]
+
+    src_dirs = None
+    if src_pattern != "omni":
+        if src_ori is None:
+            raise ValueError("source orientation required for non-omni directivity")
+        src_dirs = orientation_to_unit(src_ori, dim)
+        if src_dirs.ndim == 1:
+            src_dirs = src_dirs.unsqueeze(0).repeat(n_src, 1)
+        if src_dirs.ndim != 2 or src_dirs.shape[0] != n_src:
+            raise ValueError("source orientation must match number of sources")
+
+    for start in range(0, n_vec.shape[0], img_chunk):
+        end = min(start + img_chunk, n_vec.shape[0])
+        n_vec_chunk = n_vec[start:end]
+        refl_chunk = refl[start:end]
+        sample_chunk, attenuation_chunk = _compute_image_contributions_time_batch(
+            src_traj,
+            mic_traj,
+            room_size,
+            n_vec_chunk,
+            refl_chunk,
+            room,
+            fdl2,
+            src_pattern=src_pattern,
+            mic_pattern=mic_pattern,
+            src_dirs=src_dirs,
+            mic_dir=mic_dir,
         )
-
-
+        t_steps = src_traj.shape[0]
+        sample_flat = sample_chunk.reshape(t_steps * n_src, n_mic, -1)
+        attenuation_flat = attenuation_chunk.reshape(t_steps * n_src, n_mic, -1)
+        rir_flat = rirs.view(t_steps * n_src, n_mic, nsample)
+        _accumulate_rir_batch(rir_flat, sample_flat, attenuation_flat, cfg)
+
+    return rirs


 def _prepare_entities(
@@ -495,7 +572,11 @@ def _compute_image_contributions_batch(
     if mic_pattern != "omni":
         if mic_dir is None:
             raise ValueError("mic orientation required for non-omni directivity")
-        mic_dir =
+        mic_dir = (
+            mic_dir[None, :, None, :]
+            if mic_dir.ndim == 2
+            else mic_dir.view(1, 1, 1, -1)
+        )
         cos_theta = _cos_between(-vec, mic_dir)
         gain = gain * directivity_gain(mic_pattern, cos_theta)

@@ -503,6 +584,54 @@ def _compute_image_contributions_batch(
     return sample, attenuation


+def _compute_image_contributions_time_batch(
+    src_traj: Tensor,
+    mic_traj: Tensor,
+    room_size: Tensor,
+    n_vec: Tensor,
+    refl: Tensor,
+    room: Room,
+    fdl2: int,
+    *,
+    src_pattern: str,
+    mic_pattern: str,
+    src_dirs: Optional[Tensor],
+    mic_dir: Optional[Tensor],
+) -> Tuple[Tensor, Tensor]:
+    """Compute samples/attenuation for all time steps in batch."""
+    sign = torch.where((n_vec % 2) == 0, 1.0, -1.0).to(dtype=src_traj.dtype)
+    n = torch.floor_divide(n_vec + 1, 2).to(dtype=src_traj.dtype)
+    base = 2.0 * room_size * n
+    img = base[None, None, :, :] + sign[None, None, :, :] * src_traj[:, :, None, :]
+    vec = mic_traj[:, None, :, None, :] - img[:, :, None, :, :]
+    dist = torch.linalg.norm(vec, dim=-1)
+    dist = torch.clamp(dist, min=1e-6)
+    time = dist / room.c
+    time = time + (fdl2 / room.fs)
+    sample = time * room.fs
+
+    gain = refl.view(1, 1, 1, -1)
+    if src_pattern != "omni":
+        if src_dirs is None:
+            raise ValueError("source orientation required for non-omni directivity")
+        src_dirs_b = src_dirs[None, :, None, None, :]
+        cos_theta = _cos_between(vec, src_dirs_b)
+        gain = gain * directivity_gain(src_pattern, cos_theta)
+    if mic_pattern != "omni":
+        if mic_dir is None:
+            raise ValueError("mic orientation required for non-omni directivity")
+        mic_dir_b = (
+            mic_dir[None, None, :, None, :]
+            if mic_dir.ndim == 2
+            else mic_dir.view(1, 1, 1, 1, -1)
+        )
+        cos_theta = _cos_between(-vec, mic_dir_b)
+        gain = gain * directivity_gain(mic_pattern, cos_theta)
+
+    attenuation = gain / dist
+    return sample, attenuation
+
+
 def _select_orientation(orientation: Tensor, idx: int, count: int, dim: int) -> Tensor:
     """Pick the correct orientation vector for a given entity index."""
     if orientation.ndim == 0:
@@ -542,9 +671,9 @@ def _accumulate_rir(
     if use_lut:
         sinc_lut = _get_sinc_lut(fdl, lut_gran, device=rir.device, dtype=dtype)

-        mic_offsets = (
-            n_mic,
-        )
+        mic_offsets = (
+            torch.arange(n_mic, device=rir.device, dtype=torch.int64) * nsample
+        ).view(n_mic, 1, 1)
         rir_flat = rir.view(-1)

         chunk_size = cfg.accumulate_chunk_size
@@ -559,7 +688,9 @@
             x_off_frac = (1.0 - frac_m) * lut_gran
             lut_gran_off = torch.floor(x_off_frac).to(torch.int64)
             x_off = x_off_frac - lut_gran_off.to(dtype)
-            lut_pos = lut_gran_off[..., None] + (
+            lut_pos = lut_gran_off[..., None] + (
+                n[None, None, :].to(torch.int64) * lut_gran
+            )

             s0 = torch.take(sinc_lut, lut_pos)
             s1 = torch.take(sinc_lut, lut_pos + 1)
@@ -618,9 +749,9 @@ def _accumulate_rir_batch_impl(
     if use_lut:
         sinc_lut = _get_sinc_lut(fdl, lut_gran, device=rir.device, dtype=sample.dtype)

-        sm_offsets = (
-            n_sm,
-        )
+        sm_offsets = (
+            torch.arange(n_sm, device=rir.device, dtype=torch.int64) * nsample
+        ).view(n_sm, 1, 1)
         rir_flat = rir.view(-1)

         n_img = idx0.shape[1]
@@ -634,7 +765,9 @@ def _accumulate_rir_batch_impl(
         x_off_frac = (1.0 - frac_m) * lut_gran
         lut_gran_off = torch.floor(x_off_frac).to(torch.int64)
         x_off = x_off_frac - lut_gran_off.to(sample.dtype)
-        lut_pos = lut_gran_off[..., None] + (
+        lut_pos = lut_gran_off[..., None] + (
+            n[None, None, :].to(torch.int64) * lut_gran
+        )

         s0 = torch.take(sinc_lut, lut_pos)
         s1 = torch.take(sinc_lut, lut_pos + 1)
@@ -660,12 +793,13 @@ _SINC_LUT_CACHE: dict[tuple[int, int, str, torch.dtype], Tensor] = {}
 _FDL_GRID_CACHE: dict[tuple[int, str, torch.dtype], Tensor] = {}
 _FDL_OFFSETS_CACHE: dict[tuple[int, str], Tensor] = {}
 _FDL_WINDOW_CACHE: dict[tuple[int, str, torch.dtype], Tensor] = {}
-
+_AccumFn = Callable[[Tensor, Tensor, Tensor], None]
+_ACCUM_BATCH_COMPILED: dict[tuple[str, torch.dtype, int, int, bool, int], _AccumFn] = {}


 def _get_accumulate_fn(
     cfg: SimulationConfig, device: torch.device, dtype: torch.dtype
-) ->
+) -> _AccumFn:
     """Return an accumulation function with config-bound constants."""
     use_lut = cfg.use_lut and device.type != "mps"
     fdl = cfg.frac_delay_length
@@ -721,7 +855,9 @@ def _get_fdl_window(fdl: int, *, device: torch.device, dtype: torch.dtype) -> Te
     return cached


-def _get_sinc_lut(
+def _get_sinc_lut(
+    fdl: int, lut_gran: int, *, device: torch.device, dtype: torch.dtype
+) -> Tensor:
     """Create a sinc lookup table for fractional delays."""
     key = (fdl, lut_gran, str(device), dtype)
     cached = _SINC_LUT_CACHE.get(key)
@@ -765,7 +901,12 @@ def _apply_diffuse_tail(

     gen = torch.Generator(device=rir.device)
     gen.manual_seed(0 if seed is None else seed)
-    noise = torch.randn(
-
+    noise = torch.randn(
+        rir[..., tdiff_idx:].shape, device=rir.device, dtype=rir.dtype, generator=gen
+    )
+    scale = (
+        torch.linalg.norm(rir[..., tdiff_idx - 1 : tdiff_idx], dim=-1, keepdim=True)
+        + 1e-8
+    )
     rir[..., tdiff_idx:] = noise * decay * scale
     return rir
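The bulk of this change replaces the per-frame loop in simulate_dynamic_rir with a time-batched image-source computation (_compute_image_contributions_time_batch) that is accumulated chunk by chunk. A rough usage sketch follows, with the call inferred from the keyword names and tensor shapes visible in the hunks above; treat the import path and exact signature as assumptions:

    import torch
    from torchrir import Room, simulate_dynamic_rir  # assumed re-exports

    # Room constructor taken from the simulate_rir doctest in this diff.
    room = Room.shoebox(size=[6.0, 4.0, 3.0], fs=16000, beta=[0.9] * 6)

    T = 50  # trajectory steps
    t = torch.linspace(0.0, 1.0, T)
    # (T, dim) trajectories are unsqueezed to (T, 1, dim) by the function.
    src_traj = torch.stack([1.0 + 2.0 * t, torch.full_like(t, 2.0), torch.full_like(t, 1.5)], dim=-1)
    mic_traj = torch.tensor([2.0, 2.0, 1.5]).expand(T, 3)

    rirs = simulate_dynamic_rir(
        room=room,
        src_traj=src_traj,
        mic_traj=mic_traj,
        tmax=0.25,      # or nsample=...; per the new validation, one of the two is required
        max_order=10,
    )
    # The batched path allocates the result as (T, n_src, n_mic, nsample).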
torchrir/datasets/__init__.py
CHANGED
@@ -1,14 +1,15 @@
 """Dataset helpers for torchrir."""

-from .base import BaseDataset, SentenceLike
-from .utils import choose_speakers, load_dataset_sources
+from .base import BaseDataset, DatasetItem, SentenceLike
+from .utils import choose_speakers, load_dataset_sources, load_wav_mono
+from .collate import CollateBatch, collate_dataset_items
 from .template import TemplateDataset, TemplateSentence
+from .librispeech import LibriSpeechDataset, LibriSpeechSentence

 from .cmu_arctic import (
     CmuArcticDataset,
     CmuArcticSentence,
     list_cmu_arctic_speakers,
-    load_wav_mono,
     save_wav,
 )

@@ -17,6 +18,9 @@ __all__ = [
     "CmuArcticDataset",
     "CmuArcticSentence",
     "choose_speakers",
+    "DatasetItem",
+    "CollateBatch",
+    "collate_dataset_items",
     "list_cmu_arctic_speakers",
     "SentenceLike",
     "load_dataset_sources",
@@ -24,4 +28,6 @@ __all__ = [
     "save_wav",
     "TemplateDataset",
     "TemplateSentence",
+    "LibriSpeechDataset",
+    "LibriSpeechSentence",
 ]
torchrir/datasets/base.py
CHANGED
@@ -2,9 +2,11 @@ from __future__ import annotations

 """Dataset protocol definitions."""

-from
+from dataclasses import dataclass
+from typing import Optional, Protocol, Sequence, Tuple

 import torch
+from torch.utils.data import Dataset


 class SentenceLike(Protocol):
@@ -14,14 +16,52 @@ class SentenceLike(Protocol):
     text: str


-
-
+@dataclass(frozen=True)
+class DatasetItem:
+    """Dataset item for DataLoader consumption."""
+
+    audio: torch.Tensor
+    sample_rate: int
+    utterance_id: str
+    text: Optional[str] = None
+    speaker: Optional[str] = None
+
+
+class BaseDataset(Dataset[DatasetItem]):
+    """Base dataset class compatible with torch.utils.data.Dataset."""
+
+    _sentences_cache: Optional[list[SentenceLike]] = None

     def list_speakers(self) -> list[str]:
         """Return available speaker IDs."""
+        raise NotImplementedError

     def available_sentences(self) -> Sequence[SentenceLike]:
         """Return sentence entries that have audio available."""
+        raise NotImplementedError

     def load_wav(self, utterance_id: str) -> Tuple[torch.Tensor, int]:
         """Load audio for an utterance and return (audio, sample_rate)."""
+        raise NotImplementedError
+
+    def __len__(self) -> int:
+        return len(self._get_sentences())
+
+    def __getitem__(self, idx: int) -> DatasetItem:
+        sentences = self._get_sentences()
+        sentence = sentences[idx]
+        audio, sample_rate = self.load_wav(sentence.utterance_id)
+        speaker = getattr(self, "speaker", None)
+        text = getattr(sentence, "text", None)
+        return DatasetItem(
+            audio=audio,
+            sample_rate=sample_rate,
+            utterance_id=sentence.utterance_id,
+            text=text,
+            speaker=speaker,
+        )
+
+    def _get_sentences(self) -> list[SentenceLike]:
+        if self._sentences_cache is None:
+            self._sentences_cache = list(self.available_sentences())
+        return self._sentences_cache
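With BaseDataset now deriving from torch.utils.data.Dataset, a subclass only needs list_speakers, available_sentences, and load_wav; indexing then yields DatasetItem instances. A minimal sketch with a synthetic corpus (ToyDataset is hypothetical, not part of the package):

    from types import SimpleNamespace
    from typing import Sequence, Tuple

    import torch

    from torchrir.datasets import BaseDataset, DatasetItem, SentenceLike


    class ToyDataset(BaseDataset):
        """Two synthetic sine-tone 'utterances' standing in for a real corpus."""

        speaker = "toy"  # __getitem__ reads this via getattr(self, "speaker", None)

        def list_speakers(self) -> list[str]:
            return [self.speaker]

        def available_sentences(self) -> Sequence[SentenceLike]:
            # Anything carrying an utterance_id (and optionally text) works here.
            return [SimpleNamespace(utterance_id=f"utt{i:03d}", text="hello") for i in range(2)]

        def load_wav(self, utterance_id: str) -> Tuple[torch.Tensor, int]:
            fs = 16000
            t = torch.arange(fs, dtype=torch.float32) / fs
            return torch.sin(2 * torch.pi * 440.0 * t), fs


    item = ToyDataset()[0]
    assert isinstance(item, DatasetItem) and item.sample_rate == 16000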
torchrir/datasets/cmu_arctic.py
CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations

 """CMU ARCTIC dataset helpers."""

+import logging
 import tarfile
 import urllib.request
 from dataclasses import dataclass
@@ -9,7 +10,9 @@ from pathlib import Path
 from typing import List, Tuple

 import torch
-
+
+from .base import BaseDataset
+from .utils import load_wav_mono

 BASE_URL = "http://www.festvox.org/cmu_arctic/packed"
 VALID_SPEAKERS = {
@@ -44,11 +47,12 @@ def list_cmu_arctic_speakers() -> List[str]:
 @dataclass
 class CmuArcticSentence:
     """Sentence metadata from CMU ARCTIC."""
+
     utterance_id: str
     text: str


-class CmuArcticDataset:
+class CmuArcticDataset(BaseDataset):
     """CMU ARCTIC dataset loader.

     Example:
@@ -56,7 +60,9 @@ class CmuArcticDataset:
         >>> audio, fs = dataset.load_wav("arctic_a0001")
     """

-    def __init__(
+    def __init__(
+        self, root: Path, speaker: str = "bdl", download: bool = False
+    ) -> None:
         """Initialize a CMU ARCTIC dataset handle.

         Args:
@@ -188,23 +194,6 @@ def _parse_text_line(line: str) -> Tuple[str, str]:
     return utterance, text


-def load_wav_mono(path: Path) -> Tuple[torch.Tensor, int]:
-    """Load a wav file and return mono audio and sample rate.
-
-    Example:
-        >>> audio, fs = load_wav_mono(Path("datasets/cmu_arctic/ARCTIC/.../wav/arctic_a0001.wav"))
-    """
-    import soundfile as sf
-
-    audio, sample_rate = sf.read(str(path), dtype="float32", always_2d=True)
-    audio_t = torch.from_numpy(audio)
-    if audio_t.shape[1] > 1:
-        audio_t = audio_t.mean(dim=1)
-    else:
-        audio_t = audio_t.squeeze(1)
-    return audio_t, sample_rate
-
-
 def save_wav(path: Path, audio: torch.Tensor, sample_rate: int) -> None:
     """Save a mono or multi-channel wav to disk.

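The removal above is a relocation rather than a deletion: load_wav_mono now comes from torchrir/datasets/utils.py (compare the utils.py entry in the file list and the updated datasets/__init__.py import), so the public import keeps working:

    from pathlib import Path
    from torchrir.datasets import load_wav_mono  # re-exported via datasets/utils.py in 0.2.0

    audio, fs = load_wav_mono(Path("path/to/file.wav"))  # mono float32 tensor plus sample rate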
torchrir/datasets/collate.py
ADDED
@@ -0,0 +1,90 @@
+from __future__ import annotations
+
+"""Collate helpers for DataLoader usage."""
+
+from dataclasses import dataclass
+from typing import Any, Iterable, List, Optional
+
+import torch
+from torch import Tensor
+
+from .base import DatasetItem
+
+
+@dataclass(frozen=True)
+class CollateBatch:
+    """Collated batch of dataset items.
+
+    Attributes:
+        audio: Padded audio tensor of shape (batch, max_len).
+        lengths: Original lengths for each item.
+        sample_rate: Sample rate shared across the batch.
+        utterance_ids: Utterance IDs per item.
+        texts: Optional text per item.
+        speakers: Optional speaker IDs per item.
+        metadata: Optional per-item metadata (pass-through).
+    """
+
+    audio: Tensor
+    lengths: Tensor
+    sample_rate: int
+    utterance_ids: list[str]
+    texts: list[Optional[str]]
+    speakers: list[Optional[str]]
+    metadata: Optional[list[Any]] = None
+
+
+def collate_dataset_items(
+    items: Iterable[DatasetItem],
+    *,
+    pad_value: float = 0.0,
+    keep_metadata: bool = False,
+) -> CollateBatch:
+    """Collate DatasetItem entries into a padded batch.
+
+    Args:
+        items: Iterable of DatasetItem.
+        pad_value: Value used for padding.
+        keep_metadata: Preserve item-level metadata field if present.
+
+    Returns:
+        CollateBatch with padded audio and metadata lists.
+    """
+    batch = list(items)
+    if not batch:
+        raise ValueError("collate_dataset_items received an empty batch")
+
+    sample_rate = batch[0].sample_rate
+    for item in batch[1:]:
+        if item.sample_rate != sample_rate:
+            raise ValueError("sample_rate must be consistent within a batch")
+
+    lengths = torch.tensor([item.audio.numel() for item in batch], dtype=torch.long)
+    max_len = int(lengths.max().item())
+    audio = torch.full(
+        (len(batch), max_len),
+        pad_value,
+        dtype=batch[0].audio.dtype,
+        device=batch[0].audio.device,
+    )
+
+    for idx, item in enumerate(batch):
+        audio[idx, : item.audio.numel()] = item.audio
+
+    utterance_ids = [item.utterance_id for item in batch]
+    texts = [item.text for item in batch]
+    speakers = [item.speaker for item in batch]
+
+    metadata: Optional[list[Any]] = None
+    if keep_metadata:
+        metadata = [getattr(item, "metadata", None) for item in batch]
+
+    return CollateBatch(
+        audio=audio,
+        lengths=lengths,
+        sample_rate=sample_rate,
+        utterance_ids=utterance_ids,
+        texts=texts,
+        speakers=speakers,
+        metadata=metadata,
+    )