torchrir 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
torchrir/plotting.py ADDED
@@ -0,0 +1,210 @@
1
+ from __future__ import annotations
2
+
3
+ """Matplotlib-based plotting helpers for room scenes."""
4
+
5
+ from typing import Any, Iterable, Optional, Sequence, Tuple
6
+
7
+ import torch
8
+ from torch import Tensor
9
+
10
+ from .room import MicrophoneArray, Room, Source
11
+ from .utils import as_tensor, ensure_dim
12
+
13
+
14
def plot_scene_static(
    *,
    room: Room | Sequence[float] | Tensor,
    sources: Source | Tensor | Sequence,
    mics: MicrophoneArray | Tensor | Sequence,
    ax: Any | None = None,
    title: Optional[str] = None,
    show: bool = False,
):
    """Render a static scene: room outline plus source and mic markers.

    Args:
        room: Room object or raw room-size tensor/sequence.
        sources: Source object or raw positions.
        mics: MicrophoneArray object or raw positions.
        ax: Existing axes to draw on; created when None.
        title: Optional axes title.
        show: Call ``plt.show()`` when True.

    Returns:
        The matplotlib axes used for drawing.
    """
    plt, ax = _setup_axes(ax, room)
    _draw_room(ax, _room_size(room, ax))

    # Normalize both entity inputs to (N, dim) tensors, then scatter them.
    for entity, label, marker in (
        (sources, "sources", "^"),
        (mics, "mics", "o"),
    ):
        _scatter_positions(ax, _extract_positions(entity, ax), label=label, marker=marker)

    if title:
        ax.set_title(title)
    ax.legend(loc="best")
    if show:
        plt.show()
    return ax
41
+
42
+
43
def plot_scene_dynamic(
    *,
    room: Room | Sequence[float] | Tensor,
    src_traj: Tensor | Sequence,
    mic_traj: Tensor | Sequence,
    step: int = 1,
    ax: Any | None = None,
    title: Optional[str] = None,
    show: bool = False,
):
    """Render source and microphone trajectories inside the room outline.

    Args:
        room: Room object or raw room-size tensor/sequence.
        src_traj: Source trajectory, shape (T, n_src, dim).
        mic_traj: Microphone trajectory, shape (T, n_mic, dim).
        step: Subsampling stride applied along the time axis.
        ax: Existing axes to draw on; created when None.
        title: Optional axes title.
        show: Call ``plt.show()`` when True.

    Returns:
        The matplotlib axes used for drawing.
    """
    plt, ax = _setup_axes(ax, room)
    _draw_room(ax, _room_size(room, ax))

    # Validate/normalize both trajectories to (T, N, dim) before plotting.
    _plot_trajectories(ax, _as_trajectory(src_traj, ax), step=step, label="source path")
    _plot_trajectories(ax, _as_trajectory(mic_traj, ax), step=step, label="mic path")

    if title:
        ax.set_title(title)
    ax.legend(loc="best")
    if show:
        plt.show()
    return ax
71
+
72
+
73
def _setup_axes(ax: Any | None, room: Room | Sequence[float] | Tensor) -> tuple[Any, Any]:
    """Return ``(pyplot module, axes)``, creating 2D or 3D axes when needed."""
    import matplotlib.pyplot as plt

    # Always normalize the room size so invalid input fails fast,
    # even when the caller supplied their own axes.
    dim = _room_size(room, ax).numel()
    if ax is None:
        if dim == 3:
            # 3-element room size requires a 3D projection.
            ax = plt.figure().add_subplot(111, projection="3d")
        else:
            _, ax = plt.subplots()
    return plt, ax
86
+
87
+
88
def _room_size(room: Room | Sequence[float] | Tensor, ax: Any | None) -> Tensor:
    """Normalize a Room object or a raw size input to a 1D tensor."""
    raw = room.size if isinstance(room, Room) else room
    return ensure_dim(as_tensor(raw))
97
+
98
+
99
def _draw_room(ax: Any, size: Tensor) -> None:
    """Dispatch to the 2D or 3D room-outline renderer based on size length."""
    renderer = _draw_room_2d if size.numel() == 2 else _draw_room_3d
    renderer(ax, size)
106
+
107
+
108
def _draw_room_2d(ax: Any, size: Tensor) -> None:
    """Draw a 2D rectangular room outline anchored at the origin."""
    import matplotlib.patches as patches

    width = size[0].item()
    height = size[1].item()
    outline = patches.Rectangle((0.0, 0.0), width, height, fill=False, edgecolor="black")
    ax.add_patch(outline)
    ax.set_xlim(0, width)
    ax.set_ylim(0, height)
    # Equal aspect so the room is not distorted by the figure shape.
    ax.set_aspect("equal", adjustable="box")
    ax.set_xlabel("x")
    ax.set_ylabel("y")
120
+
121
+
122
def _draw_room_3d(ax: Any, size: Tensor) -> None:
    """Draw the 12 edges of a 3D box anchored at the origin."""
    x, y, z = size.tolist()
    # The 8 box corners: bottom face occupies indices 0-3, top face 4-7,
    # each face ordered counter-clockwise.
    corners = torch.tensor(
        [
            [0, 0, 0],
            [x, 0, 0],
            [x, y, 0],
            [0, y, 0],
            [0, 0, z],
            [x, 0, z],
            [x, y, z],
            [0, y, z],
        ],
        dtype=torch.float32,
    )
    # Bottom ring, top ring, then the four vertical pillars.
    bottom_ring = [(i, (i + 1) % 4) for i in range(4)]
    top_ring = [(i + 4, (i + 1) % 4 + 4) for i in range(4)]
    pillars = [(i, i + 4) for i in range(4)]
    for a, b in bottom_ring + top_ring + pillars:
        ax.plot(
            [corners[a, 0], corners[b, 0]],
            [corners[a, 1], corners[b, 1]],
            [corners[a, 2], corners[b, 2]],
            color="black",
        )
    ax.set_xlim(0, x)
    ax.set_ylim(0, y)
    ax.set_zlim(0, z)
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("z")
165
+
166
+
167
def _extract_positions(entity: Source | MicrophoneArray | Tensor | Sequence, ax: Any | None) -> Tensor:
    """Pull an (N, dim) position tensor out of an entity or raw array-like."""
    raw = entity.positions if isinstance(entity, (Source, MicrophoneArray)) else entity
    pos = as_tensor(raw)
    # Promote a single point (dim,) to a batch of one (1, dim).
    return pos.unsqueeze(0) if pos.ndim == 1 else pos
177
+
178
+
179
def _scatter_positions(ax: Any, positions: Tensor, *, label: str, marker: str) -> None:
    """Scatter (N, 2) or (N, 3) positions; silently skips empty input."""
    if positions.numel() == 0:
        return
    # Pass exactly two coordinate columns for 2D data, three otherwise.
    n_axes = 2 if positions.shape[1] == 2 else 3
    ax.scatter(*(positions[:, i] for i in range(n_axes)), label=label, marker=marker)
188
+
189
+
190
def _as_trajectory(traj: Tensor | Sequence, ax: Any | None) -> Tensor:
    """Convert input to a tensor and require the (T, N, dim) layout."""
    out = as_tensor(traj)
    if out.ndim != 3:
        raise ValueError("trajectory must be of shape (T, N, dim)")
    return out
196
+
197
+
198
def _plot_trajectories(ax: Any, traj: Tensor, *, step: int, label: str) -> None:
    """Plot one line per entity from a (T, N, dim) trajectory, subsampled by step."""
    if traj.numel() == 0:
        return
    is_3d = traj.shape[2] != 2
    for idx in range(traj.shape[1]):
        # Subsample along time for entity `idx`.
        path = traj[::step, idx]
        coords = (path[:, 0], path[:, 1], path[:, 2]) if is_3d else (path[:, 0], path[:, 1])
        ax.plot(*coords, label=f"{label} {idx}")
@@ -0,0 +1,173 @@
1
+ from __future__ import annotations
2
+
3
+ """Higher-level plotting utilities used by examples."""
4
+
5
+ from pathlib import Path
6
+ from typing import Any, Optional, Sequence
7
+
8
+ import torch
9
+
10
+ from .plotting import plot_scene_dynamic, plot_scene_static
11
+
12
+
13
def plot_scene_and_save(
    *,
    out_dir: Path,
    room: Sequence[float] | torch.Tensor,
    sources: object | torch.Tensor | Sequence,
    mics: object | torch.Tensor | Sequence,
    src_traj: Optional[torch.Tensor | Sequence] = None,
    mic_traj: Optional[torch.Tensor | Sequence] = None,
    prefix: str = "scene",
    step: int = 1,
    show: bool = False,
    plot_2d: bool = True,
    plot_3d: bool = True,
) -> tuple[list[Path], list[Path]]:
    """Plot static and dynamic scenes and save images to disk.

    Args:
        out_dir: Output directory for PNGs.
        room: Room size tensor or sequence.
        sources: Source positions or Source-like object.
        mics: Microphone positions or MicrophoneArray-like object.
        src_traj: Optional source trajectory (T, n_src, dim).
        mic_traj: Optional mic trajectory (T, n_mic, dim).
        prefix: Filename prefix for saved images.
        step: Subsampling step for trajectories.
        show: Whether to show figures interactively.
        plot_2d: Save 2D projections.
        plot_3d: Save 3D projections (only if dim == 3).

    Returns:
        Tuple of (static_paths, dynamic_paths).
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Everything is plotted from detached CPU tensors so matplotlib never
    # touches device memory or autograd graphs.
    room_size = _to_cpu(room)
    src_pos = _positions_to_cpu(sources)
    mic_pos = _positions_to_cpu(mics)
    dim = int(room_size.numel())

    static_paths: list[Path] = []
    dynamic_paths: list[Path] = []

    for view_dim, enabled in ((2, plot_2d), (3, plot_3d)):
        if not enabled:
            continue
        # Skip views the scene cannot support (e.g. a 3D view of a 2D room).
        if view_dim == 2 and dim < 2:
            continue
        if view_dim == 3 and dim < 3:
            continue
        # Project room and positions down to the first view_dim coordinates.
        view_room = room_size[:view_dim]
        view_src = src_pos[:, :view_dim]
        view_mic = mic_pos[:, :view_dim]

        ax = plot_scene_static(
            room=view_room,
            sources=view_src,
            mics=view_mic,
            title=f"Room scene ({view_dim}D static)",
            show=False,
        )
        static_path = out_dir / f"{prefix}_static_{view_dim}d.png"
        _save_axes(ax, static_path, show=show)
        static_paths.append(static_path)

        if src_traj is not None or mic_traj is not None:
            steps = _traj_steps(src_traj, mic_traj)
            # NOTE: these rebind the parameters, so on the second loop
            # iteration the trajectories are already (T, N, dim) CPU tensors;
            # _trajectory_to_cpu passes such tensors through unchanged.
            src_traj = _trajectory_to_cpu(src_traj, src_pos, steps)
            mic_traj = _trajectory_to_cpu(mic_traj, mic_pos, steps)
            view_src_traj = src_traj[:, :, :view_dim]
            view_mic_traj = mic_traj[:, :, :view_dim]
            ax = plot_scene_dynamic(
                room=view_room,
                src_traj=view_src_traj,
                mic_traj=view_mic_traj,
                step=step,
                title=f"Room scene ({view_dim}D trajectories)",
                show=False,
            )
            # Mark the (static) start positions on top of the trajectory plot.
            _overlay_positions(ax, view_src, view_mic)
            dynamic_path = out_dir / f"{prefix}_dynamic_{view_dim}d.png"
            _save_axes(ax, dynamic_path, show=show)
            dynamic_paths.append(dynamic_path)

    return static_paths, dynamic_paths
98
+
99
+
100
+ def _to_cpu(value: Any) -> torch.Tensor:
101
+ """Move a value to CPU as a tensor."""
102
+ if torch.is_tensor(value):
103
+ return value.detach().cpu()
104
+ return torch.as_tensor(value).detach().cpu()
105
+
106
+
107
def _positions_to_cpu(entity: torch.Tensor | object) -> torch.Tensor:
    """Extract entity positions (or use the raw value) as an (N, dim) CPU tensor."""
    raw = getattr(entity, "positions", entity)
    pos = (raw if torch.is_tensor(raw) else torch.as_tensor(raw)).detach().cpu()
    # A single point arrives as (dim,); promote it to a batch of one.
    return pos.unsqueeze(0) if pos.ndim == 1 else pos
114
+
115
+
116
def _traj_steps(src_traj: Optional[torch.Tensor | Sequence], mic_traj: Optional[torch.Tensor | Sequence]) -> int:
    """Infer the trajectory length T, preferring the source trajectory."""
    chosen = src_traj if src_traj is not None else mic_traj
    # Only the leading (time) dimension is needed; device/grad state is irrelevant.
    return int(torch.as_tensor(chosen).shape[0])
121
+
122
+
123
def _trajectory_to_cpu(
    traj: Optional[torch.Tensor | Sequence], fallback_pos: torch.Tensor, steps: int
) -> torch.Tensor:
    """Return a (T, N, dim) CPU trajectory, synthesizing a static one when absent."""
    if traj is None:
        # No motion: tile the static positions across all time steps.
        return fallback_pos.unsqueeze(0).repeat(steps, 1, 1)
    out = (traj if torch.is_tensor(traj) else torch.as_tensor(traj)).detach().cpu()
    if out.ndim != 3:
        raise ValueError("trajectory must be of shape (T, N, dim)")
    return out
133
+
134
+
135
def _save_axes(ax: Any, path: Path, *, show: bool) -> None:
    """Save the figure that owns *ax* to *path* as a 150-dpi image.

    Args:
        ax: Matplotlib axes whose parent figure is saved.
        path: Destination file path; format inferred from the suffix.
        show: Display the figure interactively before closing it.
    """
    import matplotlib.pyplot as plt

    fig = ax.figure
    fig.tight_layout()
    fig.savefig(path, dpi=150)
    if show:
        plt.show()
    # Close explicitly so batch plotting does not accumulate open figures.
    plt.close(fig)
145
+
146
+
147
def _overlay_positions(ax: Any, sources: torch.Tensor, mics: torch.Tensor) -> None:
    """Overlay static source and mic markers on top of an existing plot."""
    # Same drawing order as before: sources first, then mics, then the legend.
    for points, marker, label, color in (
        (sources, "^", "sources", "tab:green"),
        (mics, "o", "mics", "tab:orange"),
    ):
        if points.numel() == 0:
            continue
        # 2D data scatters (x, y); anything else scatters the first three columns.
        n_axes = 2 if points.shape[1] == 2 else 3
        ax.scatter(
            *(points[:, i] for i in range(n_axes)),
            marker=marker,
            label=label,
            color=color,
        )
    ax.legend(loc="best")
torchrir/results.py ADDED
@@ -0,0 +1,22 @@
1
+ from __future__ import annotations
2
+
3
+ """Result containers for simulation outputs."""
4
+
5
+ from dataclasses import dataclass
6
+ from typing import Optional
7
+
8
+ from torch import Tensor
9
+
10
+ from .config import SimulationConfig
11
+ from .scene import Scene
12
+
13
+
14
@dataclass(frozen=True)
class RIRResult:
    """Container for RIRs with metadata.

    Immutable bundle tying simulated room impulse responses to the scene
    and configuration that produced them.
    """

    # Simulated room impulse responses; exact tensor layout is defined by the
    # simulator that produced them — TODO confirm (e.g. (n_src, n_mic, taps)).
    rirs: Tensor
    # The Scene (room, sources, mics, trajectories) that was simulated.
    scene: Scene
    # Simulation parameters used to generate `rirs`.
    config: SimulationConfig
    # Optional timestamps, presumably one per trajectory step — verify with caller.
    timestamps: Optional[Tensor] = None
    # RNG seed used for the simulation, when one was set.
    seed: Optional[int] = None
torchrir/room.py ADDED
@@ -0,0 +1,150 @@
1
+ from __future__ import annotations
2
+
3
+ """Room, source, and microphone geometry models."""
4
+
5
+ from dataclasses import dataclass, replace
6
+ from typing import Optional, Sequence
7
+
8
+ import torch
9
+ from torch import Tensor
10
+
11
+ from .utils import as_tensor, ensure_dim
12
+
13
+
14
@dataclass(frozen=True)
class Room:
    """Room geometry and acoustic parameters."""

    # Room dimensions per axis; normalized to a 1D tensor in __post_init__.
    size: Tensor
    # Sampling rate — presumably in Hz; verify against simulator usage.
    fs: float
    # Speed of sound (343.0 is the common ~20°C air value).
    c: float = 343.0
    # Per-wall reflection coefficients; mutually exclusive with t60.
    beta: Optional[Tensor] = None
    # Target reverberation time; mutually exclusive with beta.
    t60: Optional[float] = None

    def __post_init__(self) -> None:
        """Validate room size and reflection parameters."""
        size = ensure_dim(self.size)
        # frozen=True blocks normal assignment; bypass via object.__setattr__.
        object.__setattr__(self, "size", size)
        if self.beta is not None and self.t60 is not None:
            raise ValueError("beta and t60 are mutually exclusive")

    def replace(self, **kwargs) -> "Room":
        """Return a new Room with updated fields."""
        # `replace` in the body resolves to dataclasses.replace (module level).
        return replace(self, **kwargs)

    @staticmethod
    def shoebox(
        size: Sequence[float] | Tensor,
        *,
        fs: float,
        c: float = 343.0,
        beta: Optional[Sequence[float] | Tensor] = None,
        t60: Optional[float] = None,
        device: Optional[torch.device | str] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> "Room":
        """Create a rectangular (shoebox) room.

        Args:
            size: Room dimensions per axis.
            fs: Sampling rate.
            c: Speed of sound.
            beta: Optional reflection coefficients (exclusive with t60).
            t60: Optional reverberation time (exclusive with beta).
            device: Target device for the created tensors.
            dtype: Target dtype for the created tensors.
        """
        size_t = as_tensor(size, device=device, dtype=dtype)
        size_t = ensure_dim(size_t)
        beta_t = None
        if beta is not None:
            beta_t = as_tensor(beta, device=device, dtype=dtype)
        return Room(size=size_t, fs=fs, c=c, beta=beta_t, t60=t60)
53
+
54
+
55
@dataclass(frozen=True)
class Source:
    """Source container with positions and optional orientation."""

    # Source positions — presumably (n_src, dim); rank is not enforced here.
    positions: Tensor
    # Optional orientation data; shape/semantics are not constrained here.
    orientation: Optional[Tensor] = None

    def __post_init__(self) -> None:
        """Coerce positions/orientation inputs to tensors."""
        pos = as_tensor(self.positions)
        # frozen=True blocks normal assignment; bypass via object.__setattr__.
        object.__setattr__(self, "positions", pos)
        if self.orientation is not None:
            ori = as_tensor(self.orientation)
            object.__setattr__(self, "orientation", ori)

    def replace(self, **kwargs) -> "Source":
        """Return a new Source with updated fields."""
        # `replace` in the body resolves to dataclasses.replace (module level).
        return replace(self, **kwargs)

    # NOTE(review): this classmethod shares its name with the `positions`
    # field above. The dataclass machinery appears to pick the classmethod up
    # as the field's default, so `Source()` with no arguments would produce a
    # non-tensor `positions` — verify, and prefer `from_positions` in new code.
    @classmethod
    def positions(
        cls,
        positions: Sequence[Sequence[float]] | Tensor,
        *,
        orientation: Optional[Sequence[float] | Tensor] = None,
        device: Optional[torch.device | str] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> "Source":
        """Construct a Source from positions."""
        return cls.from_positions(
            positions, orientation=orientation, device=device, dtype=dtype
        )

    @classmethod
    def from_positions(
        cls,
        positions: Sequence[Sequence[float]] | Tensor,
        *,
        orientation: Optional[Sequence[float] | Tensor] = None,
        device: Optional[torch.device | str] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> "Source":
        """Convert positions/orientation to tensors and build a Source."""
        pos = as_tensor(positions, device=device, dtype=dtype)
        ori = None
        if orientation is not None:
            ori = as_tensor(orientation, device=device, dtype=dtype)
        return cls(pos, ori)
102
+
103
+
104
@dataclass(frozen=True)
class MicrophoneArray:
    """Microphone array container."""

    # Microphone positions — presumably (n_mic, dim); rank is not enforced here.
    positions: Tensor
    # Optional orientation data; shape/semantics are not constrained here.
    orientation: Optional[Tensor] = None

    def __post_init__(self) -> None:
        """Coerce positions/orientation inputs to tensors."""
        pos = as_tensor(self.positions)
        # frozen=True blocks normal assignment; bypass via object.__setattr__.
        object.__setattr__(self, "positions", pos)
        if self.orientation is not None:
            ori = as_tensor(self.orientation)
            object.__setattr__(self, "orientation", ori)

    def replace(self, **kwargs) -> "MicrophoneArray":
        """Return a new MicrophoneArray with updated fields."""
        # `replace` in the body resolves to dataclasses.replace (module level).
        return replace(self, **kwargs)

    # NOTE(review): this classmethod shares its name with the `positions`
    # field above (same hazard as in Source). Prefer `from_positions` in new
    # code; consider deprecating this alias.
    @classmethod
    def positions(
        cls,
        positions: Sequence[Sequence[float]] | Tensor,
        *,
        orientation: Optional[Sequence[float] | Tensor] = None,
        device: Optional[torch.device | str] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> "MicrophoneArray":
        """Construct a MicrophoneArray from positions."""
        return cls.from_positions(
            positions, orientation=orientation, device=device, dtype=dtype
        )

    @classmethod
    def from_positions(
        cls,
        positions: Sequence[Sequence[float]] | Tensor,
        *,
        orientation: Optional[Sequence[float] | Tensor] = None,
        device: Optional[torch.device | str] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> "MicrophoneArray":
        """Convert positions/orientation to tensors and build a MicrophoneArray."""
        pos = as_tensor(positions, device=device, dtype=dtype)
        ori = None
        if orientation is not None:
            ori = as_tensor(orientation, device=device, dtype=dtype)
        return cls(pos, ori)
torchrir/scene.py ADDED
@@ -0,0 +1,67 @@
1
+ from __future__ import annotations
2
+
3
+ """Scene container for simulation inputs."""
4
+
5
+ from dataclasses import dataclass
6
+ from typing import Optional
7
+
8
+ import torch
9
+ from torch import Tensor
10
+
11
+ from .room import MicrophoneArray, Room, Source
12
+
13
+
14
@dataclass(frozen=True)
class Scene:
    """Container for room, sources, microphones, and optional trajectories."""

    # Static room geometry and acoustics.
    room: Room
    # Source positions (and optional orientation).
    sources: Source
    # Microphone positions (and optional orientation).
    mics: MicrophoneArray
    # Optional source trajectory, shape (T, n_src, dim) or (T, dim) for one source.
    src_traj: Optional[Tensor] = None
    # Optional mic trajectory, shape (T, n_mic, dim) or (T, dim) for one mic.
    mic_traj: Optional[Tensor] = None

    def is_dynamic(self) -> bool:
        """Return True if any trajectory is provided."""
        return self.src_traj is not None or self.mic_traj is not None

    def validate(self) -> None:
        """Validate scene consistency and trajectory shapes.

        Raises:
            TypeError: If room/sources/mics are not the expected types,
                or a trajectory is not a Tensor.
            ValueError: If a trajectory has the wrong shape, or the two
                trajectories disagree on the number of time steps.
        """
        if not isinstance(self.room, Room):
            raise TypeError("room must be a Room instance")
        if not isinstance(self.sources, Source):
            raise TypeError("sources must be a Source instance")
        if not isinstance(self.mics, MicrophoneArray):
            raise TypeError("mics must be a MicrophoneArray instance")

        dim = self.room.size.numel()
        n_src = self.sources.positions.shape[0]
        n_mic = self.mics.positions.shape[0]

        # Each helper returns the trajectory's T (or None when absent).
        t_src = _validate_traj(self.src_traj, n_src, dim, "src_traj")
        t_mic = _validate_traj(self.mic_traj, n_mic, dim, "mic_traj")
        if t_src is not None and t_mic is not None and t_src != t_mic:
            raise ValueError("src_traj and mic_traj must have matching time steps")
45
+
46
+
47
+ def _validate_traj(
48
+ traj: Optional[Tensor],
49
+ count: int,
50
+ dim: int,
51
+ name: str,
52
+ ) -> Optional[int]:
53
+ if traj is None:
54
+ return None
55
+ if not torch.is_tensor(traj):
56
+ raise TypeError(f"{name} must be a Tensor")
57
+ if traj.ndim == 2:
58
+ if count != 1:
59
+ raise ValueError(f"{name} must have shape (T, {count}, {dim})")
60
+ if traj.shape[1] != dim:
61
+ raise ValueError(f"{name} must have shape (T, {dim}) for single entity")
62
+ return traj.shape[0]
63
+ if traj.ndim == 3:
64
+ if traj.shape[1] != count or traj.shape[2] != dim:
65
+ raise ValueError(f"{name} must have shape (T, {count}, {dim})")
66
+ return traj.shape[0]
67
+ raise ValueError(f"{name} must have shape (T, {count}, {dim})")
@@ -0,0 +1,51 @@
1
+ from __future__ import annotations
2
+
3
+ """Scene-agnostic utilities for example setups."""
4
+
5
+ import random
6
+ from typing import List
7
+
8
+ import torch
9
+
10
+
11
def sample_positions(
    *,
    num: int,
    room_size: torch.Tensor,
    rng: random.Random,
    margin: float = 0.5,
) -> torch.Tensor:
    """Draw `num` uniform random points inside the room, `margin` away from walls.

    Args:
        num: Number of points to sample.
        room_size: 1D tensor of room dimensions.
        rng: Seeded random generator for reproducibility.
        margin: Minimum distance to every wall.

    Returns:
        Float32 tensor of shape (num, dim).
    """
    # Per-axis sampling interval [margin, size - margin].
    bounds = [
        (margin, float(room_size[axis].item()) - margin)
        for axis in range(room_size.numel())
    ]
    samples = [[rng.uniform(lo, hi) for lo, hi in bounds] for _ in range(num)]
    return torch.tensor(samples, dtype=torch.float32)
27
+
28
+
29
def linear_trajectory(start: torch.Tensor, end: torch.Tensor, steps: int) -> torch.Tensor:
    """Create a linear trajectory from ``start`` to ``end`` (both inclusive).

    Args:
        start: Starting position tensor of shape (dim,).
        end: Ending position tensor, same shape as ``start``.
        steps: Number of trajectory points; must be >= 1.

    Returns:
        Tensor of shape (steps, dim). For ``steps == 1`` the trajectory is the
        single point ``start``.

    Raises:
        ValueError: If ``steps`` is less than 1.
    """
    if steps < 1:
        raise ValueError("steps must be >= 1")
    if steps == 1:
        # A single step has no direction to interpolate along; the previous
        # implementation divided by zero here.
        return start.unsqueeze(0).clone()
    return torch.stack(
        [start + (end - start) * t / (steps - 1) for t in range(steps)],
        dim=0,
    )
35
+
36
+
37
def binaural_mic_positions(center: torch.Tensor, offset: float = 0.08) -> torch.Tensor:
    """Build a two-microphone (left, right) layout straddling `center` on the x-axis."""
    # Displacement applied only along the first (x) coordinate.
    shift = torch.zeros((center.numel(),), dtype=torch.float32)
    shift[0] = offset
    # Row 0 is the left mic, row 1 the right mic.
    return torch.stack([center - shift, center + shift], dim=0)
45
+
46
+
47
def clamp_positions(positions: torch.Tensor, room_size: torch.Tensor, margin: float = 0.1) -> torch.Tensor:
    """Clamp positions into [margin, room_size - margin] per axis."""
    lower = torch.full_like(room_size, margin)
    upper = room_size - margin
    # Tensor bounds broadcast over the leading (N,) dimension, same as the
    # equivalent max(min(...)) pair.
    return positions.clamp(min=lower, max=upper)