torchrir 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
torchrir/__init__.py CHANGED
@@ -4,6 +4,8 @@ from .config import SimulationConfig, default_config
  from .core import simulate_dynamic_rir, simulate_rir
  from .dynamic import DynamicConvolver
  from .logging_utils import LoggingConfig, get_logger, setup_logging
+ from .animation import animate_scene_gif
+ from .metadata import build_metadata, save_metadata_json
  from .plotting import plot_scene_dynamic, plot_scene_static
  from .plotting_utils import plot_scene_and_save
  from .room import MicrophoneArray, Room, Source
@@ -61,6 +63,8 @@ __all__ = [
      "get_logger",
      "list_cmu_arctic_speakers",
      "LoggingConfig",
+     "animate_scene_gif",
+     "build_metadata",
      "resolve_device",
      "SentenceLike",
      "load_dataset_sources",
@@ -75,6 +79,7 @@ __all__ = [
      "plot_scene_and_save",
      "plot_scene_static",
      "save_wav",
+     "save_metadata_json",
      "Scene",
      "setup_logging",
      "SimulationConfig",
torchrir/animation.py ADDED
@@ -0,0 +1,172 @@
+ from __future__ import annotations
+
+ """Animation helpers for dynamic scenes."""
+
+ from pathlib import Path
+ from typing import Optional, Sequence
+
+ import torch
+
+ from .plotting_utils import _positions_to_cpu, _to_cpu, _traj_steps, _trajectory_to_cpu
+
+
+ def animate_scene_gif(
+     *,
+     out_path: Path,
+     room: Sequence[float] | torch.Tensor,
+     sources: object | torch.Tensor | Sequence,
+     mics: object | torch.Tensor | Sequence,
+     src_traj: Optional[torch.Tensor | Sequence] = None,
+     mic_traj: Optional[torch.Tensor | Sequence] = None,
+     step: int = 1,
+     fps: Optional[float] = None,
+     signal_len: Optional[int] = None,
+     fs: Optional[float] = None,
+     duration_s: Optional[float] = None,
+     plot_2d: bool = True,
+     plot_3d: bool = False,
+ ) -> Path:
+     """Render a GIF showing source/mic trajectories.
+
+     Args:
+         out_path: Destination GIF path.
+         room: Room size tensor or sequence.
+         sources: Source positions or Source-like object.
+         mics: Microphone positions or MicrophoneArray-like object.
+         src_traj: Optional source trajectory (T, n_src, dim).
+         mic_traj: Optional mic trajectory (T, n_mic, dim).
+         step: Subsampling step for trajectories.
+         fps: Frames per second for the GIF (auto if None).
+         signal_len: Optional signal length (samples) to infer elapsed time.
+         fs: Sample rate used with signal_len.
+         duration_s: Optional total duration in seconds (overrides signal_len/fs).
+         plot_2d: Use 2D projection if True.
+         plot_3d: Use 3D projection if True and dim == 3.
+
+     Returns:
+         The output path.
+
+     Example:
+         >>> animate_scene_gif(
+         ...     out_path=Path("outputs/scene.gif"),
+         ...     room=[6.0, 4.0, 3.0],
+         ...     sources=[[1.0, 2.0, 1.5]],
+         ...     mics=[[2.0, 2.0, 1.5]],
+         ...     src_traj=src_traj,
+         ...     mic_traj=mic_traj,
+         ...     signal_len=16000,
+         ...     fs=16000,
+         ... )
+     """
+     import matplotlib.pyplot as plt
+     from matplotlib import animation
+
+     out_path = Path(out_path)
+     out_path.parent.mkdir(parents=True, exist_ok=True)
+
+     room_size = _to_cpu(room)
+     src_pos = _positions_to_cpu(sources)
+     mic_pos = _positions_to_cpu(mics)
+     dim = int(room_size.numel())
+     view_dim = 3 if (plot_3d and dim == 3) else 2
+     view_room = room_size[:view_dim]
+     view_src = src_pos[:, :view_dim]
+     view_mic = mic_pos[:, :view_dim]
+
+     if src_traj is None and mic_traj is None:
+         raise ValueError("at least one trajectory is required for animation")
+     steps = _traj_steps(src_traj, mic_traj)
+     src_traj = _trajectory_to_cpu(src_traj, src_pos, steps)
+     mic_traj = _trajectory_to_cpu(mic_traj, mic_pos, steps)
+     view_src_traj = src_traj[:, :, :view_dim]
+     view_mic_traj = mic_traj[:, :, :view_dim]
+
+     if view_dim == 3:
+         fig = plt.figure()
+         ax = fig.add_subplot(111, projection="3d")
+         ax.set_xlim(0, view_room[0].item())
+         ax.set_ylim(0, view_room[1].item())
+         ax.set_zlim(0, view_room[2].item())
+         ax.set_xlabel("x")
+         ax.set_ylabel("y")
+         ax.set_zlabel("z")
+     else:
+         fig, ax = plt.subplots()
+         ax.set_xlim(0, view_room[0].item())
+         ax.set_ylim(0, view_room[1].item())
+         ax.set_aspect("equal", adjustable="box")
+         ax.set_xlabel("x")
+         ax.set_ylabel("y")
+
+     src_scatter = ax.scatter([], [], marker="^", color="tab:green", label="sources")
+     mic_scatter = ax.scatter([], [], marker="o", color="tab:orange", label="mics")
+     src_lines = []
+     mic_lines = []
+     for _ in range(view_src_traj.shape[1]):
+         if view_dim == 2:
+             line, = ax.plot([], [], color="tab:green", alpha=0.6)
+         else:
+             line, = ax.plot([], [], [], color="tab:green", alpha=0.6)
+         src_lines.append(line)
+     for _ in range(view_mic_traj.shape[1]):
+         if view_dim == 2:
+             line, = ax.plot([], [], color="tab:orange", alpha=0.6)
+         else:
+             line, = ax.plot([], [], [], color="tab:orange", alpha=0.6)
+         mic_lines.append(line)
+
+     ax.legend(loc="best")
+
+     if duration_s is None and signal_len is not None and fs is not None:
+         duration_s = float(signal_len) / float(fs)
+
+     def _frame(i: int):
+         idx = min(i * step, view_src_traj.shape[0] - 1)
+         src_frame = view_src_traj[: idx + 1]
+         mic_frame = view_mic_traj[: idx + 1]
+         src_pos_frame = view_src_traj[idx]
+         mic_pos_frame = view_mic_traj[idx]
+
+         if view_dim == 2:
+             src_scatter.set_offsets(src_pos_frame)
+             mic_scatter.set_offsets(mic_pos_frame)
+             for s_idx, line in enumerate(src_lines):
+                 xy = src_frame[:, s_idx, :]
+                 line.set_data(xy[:, 0], xy[:, 1])
+             for m_idx, line in enumerate(mic_lines):
+                 xy = mic_frame[:, m_idx, :]
+                 line.set_data(xy[:, 0], xy[:, 1])
+         else:
+             src_scatter._offsets3d = (
+                 src_pos_frame[:, 0],
+                 src_pos_frame[:, 1],
+                 src_pos_frame[:, 2],
+             )
+             mic_scatter._offsets3d = (
+                 mic_pos_frame[:, 0],
+                 mic_pos_frame[:, 1],
+                 mic_pos_frame[:, 2],
+             )
+             for s_idx, line in enumerate(src_lines):
+                 xyz = src_frame[:, s_idx, :]
+                 line.set_data(xyz[:, 0], xyz[:, 1])
+                 line.set_3d_properties(xyz[:, 2])
+             for m_idx, line in enumerate(mic_lines):
+                 xyz = mic_frame[:, m_idx, :]
+                 line.set_data(xyz[:, 0], xyz[:, 1])
+                 line.set_3d_properties(xyz[:, 2])
+         if duration_s is not None and steps > 1:
+             t = (idx / (steps - 1)) * duration_s
+             ax.set_title(f"t = {t:.2f} s")
+         return [src_scatter, mic_scatter, *src_lines, *mic_lines]
+
+     frames = max(1, (view_src_traj.shape[0] + step - 1) // step)
+     if fps is None or fps <= 0:
+         if duration_s is not None and duration_s > 0:
+             fps = frames / duration_s
+         else:
+             fps = 6.0
+     anim = animation.FuncAnimation(fig, _frame, frames=frames, interval=1000 / fps, blit=False)
+     anim.save(out_path, writer="pillow", fps=fps)
+     plt.close(fig)
+     return out_path
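
Taken together with the signature above, a minimal usage sketch. The trajectory construction mirrors the simulate_dynamic_rir doctest later in this diff; the room size, positions, and step count are illustrative, not prescribed values:

import torch
from pathlib import Path
from torchrir import animate_scene_gif, linear_trajectory

# One source moving across the room, one static microphone, 8 trajectory steps.
src_traj = torch.stack(
    [linear_trajectory(torch.tensor([1.0, 2.0, 1.5]), torch.tensor([4.0, 2.0, 1.5]), 8)],
    dim=1,
)  # (T, n_src, dim)
mic_traj = torch.tensor([[2.0, 2.0, 1.5]]).unsqueeze(0).repeat(8, 1, 1)  # (T, n_mic, dim)

animate_scene_gif(
    out_path=Path("outputs/scene.gif"),
    room=[6.0, 4.0, 3.0],
    sources=src_traj[0],
    mics=mic_traj[0],
    src_traj=src_traj,
    mic_traj=mic_traj,
    signal_len=16000,
    fs=16000,
)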
torchrir/config.py CHANGED
@@ -10,7 +10,12 @@ import torch
  
  @dataclass(frozen=True)
  class SimulationConfig:
-     """Configuration values for RIR simulation and convolution."""
+     """Configuration values for RIR simulation and convolution.
+
+     Example:
+         >>> cfg = SimulationConfig(max_order=6, tmax=0.3, device="auto")
+         >>> cfg.validate()
+     """
  
      fs: Optional[float] = None
      max_order: Optional[int] = None
@@ -53,7 +58,11 @@ class SimulationConfig:
  
  
  def default_config() -> SimulationConfig:
-     """Return the default simulation configuration."""
+     """Return the default simulation configuration.
+
+     Example:
+         >>> cfg = default_config()
+     """
      cfg = SimulationConfig()
      cfg.validate()
      return cfg
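
A short sketch of passing the documented config into the simulator. The `config=` keyword is inferred from `cfg = config or default_config()` in core.py below and should be treated as an assumption:

from torchrir import SimulationConfig, simulate_rir

# Assumption: simulate_rir accepts the config via a `config=` keyword
# (inferred from `cfg = config or default_config()` in core.py).
cfg = SimulationConfig(max_order=6, tmax=0.3, device="auto")
cfg.validate()
# room, sources, and mics as in the simulate_rir doctest below.
rir = simulate_rir(room=room, sources=sources, mics=mics, config=cfg)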
torchrir/core.py CHANGED
@@ -58,6 +58,18 @@ def simulate_rir(
  
      Returns:
          Tensor of shape (n_src, n_mic, nsample).
+
+     Example:
+         >>> room = Room.shoebox(size=[6.0, 4.0, 3.0], fs=16000, beta=[0.9] * 6)
+         >>> sources = Source.positions([[1.0, 2.0, 1.5]])
+         >>> mics = MicrophoneArray.positions([[2.0, 2.0, 1.5]])
+         >>> rir = simulate_rir(
+         ...     room=room,
+         ...     sources=sources,
+         ...     mics=mics,
+         ...     max_order=6,
+         ...     tmax=0.3,
+         ... )
      """
      cfg = config or default_config()
      cfg.validate()
@@ -208,6 +220,24 @@ def simulate_dynamic_rir(
  
      Returns:
          Tensor of shape (T, n_src, n_mic, nsample).
+
+     Example:
+         >>> room = Room.shoebox(size=[6.0, 4.0, 3.0], fs=16000, beta=[0.9] * 6)
+         >>> from torchrir import linear_trajectory
+         >>> src_traj = torch.stack(
+         ...     [linear_trajectory(torch.tensor([1.0, 2.0, 1.5]),
+         ...                        torch.tensor([4.0, 2.0, 1.5]), 8)],
+         ...     dim=1,
+         ... )
+         >>> mic_pos = torch.tensor([[2.0, 2.0, 1.5]])
+         >>> mic_traj = mic_pos.unsqueeze(0).repeat(8, 1, 1)
+         >>> rirs = simulate_dynamic_rir(
+         ...     room=room,
+         ...     src_traj=src_traj,
+         ...     mic_traj=mic_traj,
+         ...     max_order=4,
+         ...     tmax=0.3,
+         ... )
      """
      cfg = config or default_config()
      cfg.validate()
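
Following on from the simulate_dynamic_rir doctest above, the returned tensor is (T, n_src, n_mic, nsample); a small illustrative check of how T relates to the trajectories built in the doctest:

# Continuing the doctest above: T matches the 8 trajectory steps used to build src_traj.
T, n_src, n_mic, nsample = rirs.shape
assert T == src_traj.shape[0] == mic_traj.shape[0] == 8
assert (n_src, n_mic) == (src_traj.shape[1], mic_traj.shape[1])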
@@ -49,6 +49,13 @@ class CmuArcticSentence:
  
  
  class CmuArcticDataset:
+     """CMU ARCTIC dataset loader.
+
+     Example:
+         >>> dataset = CmuArcticDataset(Path("datasets/cmu_arctic"), speaker="bdl", download=True)
+         >>> audio, fs = dataset.load_wav("arctic_a0001")
+     """
+
      def __init__(self, root: Path, speaker: str = "bdl", download: bool = False) -> None:
          """Initialize a CMU ARCTIC dataset handle.
  
@@ -182,7 +189,11 @@ def _parse_text_line(line: str) -> Tuple[str, str]:
  
  
  def load_wav_mono(path: Path) -> Tuple[torch.Tensor, int]:
-     """Load a wav file and return mono audio and sample rate."""
+     """Load a wav file and return mono audio and sample rate.
+
+     Example:
+         >>> audio, fs = load_wav_mono(Path("datasets/cmu_arctic/ARCTIC/.../wav/arctic_a0001.wav"))
+     """
      import soundfile as sf
  
      audio, sample_rate = sf.read(str(path), dtype="float32", always_2d=True)
@@ -195,7 +206,11 @@ def load_wav_mono(path: Path) -> Tuple[torch.Tensor, int]:
  
  
  def save_wav(path: Path, audio: torch.Tensor, sample_rate: int) -> None:
-     """Save a mono or multi-channel wav to disk."""
+     """Save a mono or multi-channel wav to disk.
+
+     Example:
+         >>> save_wav(Path("outputs/example.wav"), audio, sample_rate)
+     """
      import soundfile as sf
  
      audio = audio.detach().cpu().clamp(-1.0, 1.0).to(torch.float32)
@@ -11,7 +11,12 @@ from .base import BaseDataset, SentenceLike
  
  
  def choose_speakers(dataset: BaseDataset, num_sources: int, rng: random.Random) -> List[str]:
-     """Select unique speakers for the requested number of sources."""
+     """Select unique speakers for the requested number of sources.
+
+     Example:
+         >>> rng = random.Random(0)
+         >>> speakers = choose_speakers(dataset, num_sources=2, rng=rng)
+     """
      speakers = dataset.list_speakers()
      if not speakers:
          raise RuntimeError("no speakers available")
@@ -27,7 +32,20 @@ load_dataset_sources(
      duration_s: float,
      rng: random.Random,
  ) -> Tuple[torch.Tensor, int, List[Tuple[str, List[str]]]]:
-     """Load and concatenate utterances for each speaker into fixed-length signals."""
+     """Load and concatenate utterances for each speaker into fixed-length signals.
+
+     Example:
+         >>> from pathlib import Path
+         >>> from torchrir import CmuArcticDataset
+         >>> rng = random.Random(0)
+         >>> root = Path("datasets/cmu_arctic")
+         >>> signals, fs, info = load_dataset_sources(
+         ...     dataset_factory=lambda spk: CmuArcticDataset(root, speaker=spk, download=True),
+         ...     num_sources=2,
+         ...     duration_s=10.0,
+         ...     rng=rng,
+         ... )
+     """
      dataset0 = dataset_factory(None)
      speakers = choose_speakers(dataset0, num_sources, rng)
      signals: List[torch.Tensor] = []
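
To round out the doctest above, a hedged sketch of writing the loaded sources back to disk with save_wav. The per-source indexing assumes `signals` stacks one fixed-length signal per speaker along dim 0, which is an assumption rather than something stated in this diff:

from pathlib import Path
from torchrir import save_wav

# Assumption: signals has shape (num_sources, num_samples); fs is the shared sample rate.
for i in range(signals.shape[0]):
    save_wav(Path(f"outputs/source_{i}.wav"), signals[i], fs)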
torchrir/dynamic.py CHANGED
@@ -17,7 +17,12 @@ from .signal import _ensure_dynamic_rirs, _ensure_signal
  
  @dataclass(frozen=True)
  class DynamicConvolver:
-     """Convolver for time-varying RIRs."""
+     """Convolver for time-varying RIRs.
+
+     Example:
+         >>> convolver = DynamicConvolver(mode="trajectory")
+         >>> y = convolver.convolve(signal, rirs)
+     """
  
      mode: str = "trajectory"
      hop: Optional[int] = None
@@ -28,7 +33,11 @@
          return self.convolve(signal, rirs)
  
      def convolve(self, signal: Tensor, rirs: Tensor) -> Tensor:
-         """Convolve signals with time-varying RIRs."""
+         """Convolve signals with time-varying RIRs.
+
+         Example:
+             >>> y = DynamicConvolver(mode="hop", hop=1024).convolve(signal, rirs)
+         """
          if self.mode not in ("trajectory", "hop"):
              raise ValueError("mode must be 'trajectory' or 'hop'")
          if self.mode == "hop":
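
A sketch connecting the pieces above: dry signals from load_dataset_sources, time-varying RIRs from simulate_dynamic_rir, then DynamicConvolver. The pairing of shapes is an assumption, since this diff does not show the convolver internals:

from torchrir import DynamicConvolver

convolver = DynamicConvolver(mode="trajectory")  # or DynamicConvolver(mode="hop", hop=1024)
# __call__ simply forwards to convolve(), per the context lines above.
wet = convolver(signals, rirs)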
torchrir/logging_utils.py CHANGED
@@ -9,7 +9,12 @@ from typing import Optional
  
  @dataclass(frozen=True)
  class LoggingConfig:
-     """Configuration for torchrir logging."""
+     """Configuration for torchrir logging.
+
+     Example:
+         >>> config = LoggingConfig(level="INFO")
+         >>> logger = setup_logging(config)
+     """
  
      level: str | int = "INFO"
      format: str = "%(levelname)s:%(name)s:%(message)s"
@@ -33,7 +38,12 @@ class LoggingConfig:
  
  
  def setup_logging(config: LoggingConfig, *, name: str = "torchrir") -> logging.Logger:
-     """Configure and return the base torchrir logger."""
+     """Configure and return the base torchrir logger.
+
+     Example:
+         >>> logger = setup_logging(LoggingConfig(level="DEBUG"))
+         >>> logger.info("ready")
+     """
      logger = logging.getLogger(name)
      level = config.resolve_level()
      logger.setLevel(level)
@@ -47,7 +57,11 @@ def setup_logging(config: LoggingConfig, *, name: str = "torchrir") -> logging.L
  
  
  def get_logger(name: Optional[str] = None) -> logging.Logger:
-     """Return a torchrir logger, namespaced under the torchrir root."""
+     """Return a torchrir logger, namespaced under the torchrir root.
+
+     Example:
+         >>> logger = get_logger("examples.static")
+     """
      if not name:
          return logging.getLogger("torchrir")
      if name.startswith("torchrir"):
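
Combining the three logging helpers documented above into one snippet; the logger name "examples.dynamic" is arbitrary:

from torchrir import LoggingConfig, setup_logging, get_logger

setup_logging(LoggingConfig(level="DEBUG"))
log = get_logger("examples.dynamic")  # namespaced under the torchrir root logger
log.info("simulation started")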
torchrir/metadata.py ADDED
@@ -0,0 +1,216 @@
+ from __future__ import annotations
+
+ """Metadata helpers for simulation outputs."""
+
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Any, Dict, Optional
+
+ import json
+ import torch
+ from torch import Tensor
+
+ from .room import MicrophoneArray, Room, Source
+
+
+ @dataclass(frozen=True)
+ class ArrayAttributes:
+     """Structured description of a microphone array."""
+
+     geometry_name: str
+     positions: Tensor
+     orientation: Optional[Tensor]
+     center: Tensor
+     normal: Optional[Tensor]
+     spacing: Optional[float]
+
+
+ def build_metadata(
+     *,
+     room: Room,
+     sources: Source,
+     mics: MicrophoneArray,
+     rirs: Tensor,
+     src_traj: Optional[Tensor] = None,
+     mic_traj: Optional[Tensor] = None,
+     timestamps: Optional[Tensor] = None,
+     signal_len: Optional[int] = None,
+     source_info: Optional[Any] = None,
+     extra: Optional[Dict[str, Any]] = None,
+ ) -> Dict[str, Any]:
+     """Build JSON-serializable metadata for a simulation output.
+
+     Example:
+         >>> metadata = build_metadata(
+         ...     room=room,
+         ...     sources=sources,
+         ...     mics=mics,
+         ...     rirs=rirs,
+         ...     src_traj=src_traj,
+         ...     mic_traj=mic_traj,
+         ...     signal_len=signal.shape[-1],
+         ... )
+         >>> save_metadata_json(Path("outputs/scene_metadata.json"), metadata)
+     """
+     nsample = int(rirs.shape[-1])
+     fs = float(room.fs)
+     time_axis = {
+         "fs": fs,
+         "nsample": nsample,
+         "t": _to_serializable(torch.arange(nsample, dtype=torch.float32) / fs),
+     }
+
+     src_pos = sources.positions
+     mic_pos = mics.positions
+     dim = int(room.size.numel())
+     src_traj_n = _normalize_traj(src_traj, src_pos, dim, "src_traj")
+     mic_traj_n = _normalize_traj(mic_traj, mic_pos, dim, "mic_traj")
+
+     t_steps = max(src_traj_n.shape[0], mic_traj_n.shape[0])
+     if src_traj_n.shape[0] == 1 and t_steps > 1:
+         src_traj_n = src_traj_n.expand(t_steps, -1, -1)
+     if mic_traj_n.shape[0] == 1 and t_steps > 1:
+         mic_traj_n = mic_traj_n.expand(t_steps, -1, -1)
+     if src_traj_n.shape[0] != mic_traj_n.shape[0]:
+         raise ValueError("src_traj and mic_traj must have matching time steps")
+
+     azimuth, elevation = _compute_doa(src_traj_n, mic_traj_n)
+     doa = {
+         "frame": "world",
+         "unit": "radians",
+         "azimuth": _to_serializable(azimuth),
+         "elevation": _to_serializable(elevation),
+     }
+
+     timestamps_out: Optional[Tensor] = None
+     if timestamps is not None:
+         timestamps_out = timestamps
+     elif t_steps > 1 and signal_len is not None:
+         duration = max(0.0, (float(signal_len) - 1.0) / fs)
+         timestamps_out = torch.linspace(0.0, duration, t_steps, dtype=torch.float32)
+
+     array_attrs = _array_attributes(mics)
+
+     metadata: Dict[str, Any] = {
+         "room": {
+             "size": _to_serializable(room.size),
+             "c": float(room.c),
+             "beta": _to_serializable(room.beta) if room.beta is not None else None,
+             "t60": float(room.t60) if room.t60 is not None else None,
+             "fs": fs,
+         },
+         "sources": {
+             "positions": _to_serializable(src_pos),
+             "orientation": _to_serializable(sources.orientation),
+         },
+         "mics": {
+             "positions": _to_serializable(mic_pos),
+             "orientation": _to_serializable(mics.orientation),
+         },
+         "trajectories": {
+             "sources": _to_serializable(src_traj_n if t_steps > 1 else None),
+             "mics": _to_serializable(mic_traj_n if t_steps > 1 else None),
+         },
+         "array": {
+             "geometry": array_attrs.geometry_name,
+             "positions": _to_serializable(array_attrs.positions),
+             "orientation": _to_serializable(array_attrs.orientation),
+             "center": _to_serializable(array_attrs.center),
+             "normal": _to_serializable(array_attrs.normal),
+             "spacing": array_attrs.spacing,
+         },
+         "time_axis": time_axis,
+         "doa": doa,
+         "timestamps": _to_serializable(timestamps_out),
+         "rirs_shape": list(rirs.shape),
+         "dynamic": bool(t_steps > 1),
+     }
+
+     if source_info is not None:
+         metadata["source_info"] = _to_serializable(source_info)
+     if extra:
+         metadata["extra"] = _to_serializable(extra)
+     return metadata
+
+
+ def save_metadata_json(path: Path, metadata: Dict[str, Any]) -> None:
+     """Save metadata as JSON to the given path.
+
+     Example:
+         >>> save_metadata_json(Path("outputs/scene_metadata.json"), metadata)
+     """
+     path.parent.mkdir(parents=True, exist_ok=True)
+     with path.open("w", encoding="utf-8") as f:
+         json.dump(metadata, f, indent=2)
+
+
+ def _normalize_traj(traj: Optional[Tensor], pos: Tensor, dim: int, name: str) -> Tensor:
+     if traj is None:
+         if pos.ndim != 2 or pos.shape[1] != dim:
+             raise ValueError(f"{name} default positions must have shape (N, {dim})")
+         return pos.unsqueeze(0)
+     if not torch.is_tensor(traj):
+         raise TypeError(f"{name} must be a Tensor")
+     if traj.ndim == 2:
+         if traj.shape[1] != dim:
+             raise ValueError(f"{name} must have shape (T, {dim})")
+         return traj.unsqueeze(1)
+     if traj.ndim == 3:
+         if traj.shape[2] != dim:
+             raise ValueError(f"{name} must have shape (T, N, {dim})")
+         return traj
+     raise ValueError(f"{name} must have shape (T, N, {dim})")
+
+
+ def _compute_doa(src_traj: Tensor, mic_traj: Tensor) -> tuple[Tensor, Tensor]:
+     vec = src_traj[:, :, None, :] - mic_traj[:, None, :, :]
+     x = vec[..., 0]
+     y = vec[..., 1]
+     azimuth = torch.atan2(y, x)
+     if vec.shape[-1] < 3:
+         elevation = torch.zeros_like(azimuth)
+     else:
+         z = vec[..., 2]
+         r_xy = torch.sqrt(x**2 + y**2)
+         elevation = torch.atan2(z, r_xy)
+     return azimuth, elevation
+
+
+ def _array_attributes(mics: MicrophoneArray) -> ArrayAttributes:
+     pos = mics.positions
+     n_mic = pos.shape[0]
+     if n_mic == 1:
+         geometry = "single"
+     elif n_mic == 2:
+         geometry = "binaural"
+     else:
+         geometry = "custom"
+     center = pos.mean(dim=0)
+     spacing = None
+     if n_mic >= 2:
+         dists = torch.cdist(pos, pos)
+         dists = dists[dists > 0]
+         if dists.numel() > 0:
+             spacing = float(dists.min().item())
+     return ArrayAttributes(
+         geometry_name=geometry,
+         positions=pos,
+         orientation=mics.orientation,
+         center=center,
+         normal=None,
+         spacing=spacing,
+     )
+
+
+ def _to_serializable(value: Any) -> Any:
+     if value is None:
+         return None
+     if torch.is_tensor(value):
+         return value.detach().cpu().tolist()
+     if isinstance(value, Path):
+         return str(value)
+     if isinstance(value, dict):
+         return {k: _to_serializable(v) for k, v in value.items()}
+     if isinstance(value, (list, tuple)):
+         return [_to_serializable(v) for v in value]
+     return value
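
An end-to-end sketch based on the build_metadata doctest above; room, sources, mics, rirs, the trajectories, and signal are assumed to come from the simulation steps shown earlier in this diff:

from pathlib import Path
from torchrir import build_metadata, save_metadata_json

metadata = build_metadata(
    room=room,
    sources=sources,
    mics=mics,
    rirs=rirs,
    src_traj=src_traj,
    mic_traj=mic_traj,
    signal_len=signal.shape[-1],
)
save_metadata_json(Path("outputs/scene_metadata.json"), metadata)
# The stored DOA is per time step, in radians, in world coordinates ("frame": "world").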