torchrir 0.1.0__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
torchrir/metadata.py ADDED
@@ -0,0 +1,216 @@
1
+ from __future__ import annotations
2
+
3
+ """Metadata helpers for simulation outputs."""
4
+
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+ from typing import Any, Dict, Optional
8
+
9
+ import json
10
+ import torch
11
+ from torch import Tensor
12
+
13
+ from .room import MicrophoneArray, Room, Source
14
+
15
+
16
@dataclass(frozen=True)
class ArrayAttributes:
    """Structured description of a microphone array.

    Built by ``_array_attributes`` and embedded under the ``"array"`` key
    of the metadata produced by ``build_metadata``.
    """

    geometry_name: str  # "single" (1 mic), "binaural" (2 mics), or "custom"
    positions: Tensor  # mic positions; assumed shape (N, dim) — TODO confirm upstream
    orientation: Optional[Tensor]  # per-array orientation, passed through from MicrophoneArray
    center: Tensor  # centroid (mean over mics) of the positions
    normal: Optional[Tensor]  # array-plane normal; currently always None
    spacing: Optional[float]  # smallest non-zero inter-mic distance; None for a single mic
26
+
27
+
28
def build_metadata(
    *,
    room: Room,
    sources: Source,
    mics: MicrophoneArray,
    rirs: Tensor,
    src_traj: Optional[Tensor] = None,
    mic_traj: Optional[Tensor] = None,
    timestamps: Optional[Tensor] = None,
    signal_len: Optional[int] = None,
    source_info: Optional[Any] = None,
    extra: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Build JSON-serializable metadata for a simulation output.

    Args:
        room: Room geometry/acoustics; ``fs``, ``size``, ``c``, ``beta`` and
            ``t60`` are read here.
        sources: Source container providing ``positions`` and ``orientation``.
        mics: Microphone array providing ``positions`` and ``orientation``.
        rirs: RIR tensor; only its shape is recorded (last dim = nsample).
        src_traj: Optional source trajectory, shape ``(T, dim)`` or
            ``(T, N, dim)``; defaults to the static source positions.
        mic_traj: Optional mic trajectory with the same accepted shapes.
        timestamps: Optional explicit per-step timestamps; takes precedence
            over the linspace derived from ``signal_len``.
        signal_len: Signal length in samples, used to derive timestamps for
            dynamic scenes when ``timestamps`` is not given.
        source_info: Arbitrary extra source description, serialized as-is.
        extra: Arbitrary extra key/value data stored under ``"extra"``.

    Returns:
        A nested dict of plain Python types, safe for ``json.dump``.

    Raises:
        ValueError: If a trajectory has an invalid shape or the two
            trajectories disagree on the number of time steps.
        TypeError: If a provided trajectory is not a Tensor.

    Example:
        >>> metadata = build_metadata(
        ...     room=room,
        ...     sources=sources,
        ...     mics=mics,
        ...     rirs=rirs,
        ...     src_traj=src_traj,
        ...     mic_traj=mic_traj,
        ...     signal_len=signal.shape[-1],
        ... )
        >>> save_metadata_json(Path("outputs/scene_metadata.json"), metadata)
    """
    nsample = int(rirs.shape[-1])
    fs = float(room.fs)
    # Discrete time axis of the RIR taps, in seconds.
    time_axis = {
        "fs": fs,
        "nsample": nsample,
        "t": _to_serializable(torch.arange(nsample, dtype=torch.float32) / fs),
    }

    src_pos = sources.positions
    mic_pos = mics.positions
    dim = int(room.size.numel())
    # Normalize both trajectories to (T, N, dim); static inputs become T=1.
    src_traj_n = _normalize_traj(src_traj, src_pos, dim, "src_traj")
    mic_traj_n = _normalize_traj(mic_traj, mic_pos, dim, "mic_traj")

    # Broadcast a static (T=1) trajectory against a moving one.
    t_steps = max(src_traj_n.shape[0], mic_traj_n.shape[0])
    if src_traj_n.shape[0] == 1 and t_steps > 1:
        src_traj_n = src_traj_n.expand(t_steps, -1, -1)
    if mic_traj_n.shape[0] == 1 and t_steps > 1:
        mic_traj_n = mic_traj_n.expand(t_steps, -1, -1)
    if src_traj_n.shape[0] != mic_traj_n.shape[0]:
        raise ValueError("src_traj and mic_traj must have matching time steps")

    azimuth, elevation = _compute_doa(src_traj_n, mic_traj_n)
    doa = {
        "frame": "world",
        "unit": "radians",
        "azimuth": _to_serializable(azimuth),
        "elevation": _to_serializable(elevation),
    }

    # Explicit timestamps win; otherwise spread the signal duration evenly
    # over the trajectory steps (dynamic scenes only).
    timestamps_out: Optional[Tensor] = None
    if timestamps is not None:
        timestamps_out = timestamps
    elif t_steps > 1 and signal_len is not None:
        duration = max(0.0, (float(signal_len) - 1.0) / fs)
        timestamps_out = torch.linspace(0.0, duration, t_steps, dtype=torch.float32)

    array_attrs = _array_attributes(mics)

    metadata: Dict[str, Any] = {
        "room": {
            "size": _to_serializable(room.size),
            "c": float(room.c),
            "beta": _to_serializable(room.beta) if room.beta is not None else None,
            "t60": float(room.t60) if room.t60 is not None else None,
            "fs": fs,
        },
        "sources": {
            "positions": _to_serializable(src_pos),
            "orientation": _to_serializable(sources.orientation),
        },
        "mics": {
            "positions": _to_serializable(mic_pos),
            "orientation": _to_serializable(mics.orientation),
        },
        # Trajectories are only recorded for dynamic scenes (T > 1).
        "trajectories": {
            "sources": _to_serializable(src_traj_n if t_steps > 1 else None),
            "mics": _to_serializable(mic_traj_n if t_steps > 1 else None),
        },
        "array": {
            "geometry": array_attrs.geometry_name,
            "positions": _to_serializable(array_attrs.positions),
            "orientation": _to_serializable(array_attrs.orientation),
            "center": _to_serializable(array_attrs.center),
            "normal": _to_serializable(array_attrs.normal),
            "spacing": array_attrs.spacing,
        },
        "time_axis": time_axis,
        "doa": doa,
        "timestamps": _to_serializable(timestamps_out),
        "rirs_shape": list(rirs.shape),
        "dynamic": bool(t_steps > 1),
    }

    if source_info is not None:
        metadata["source_info"] = _to_serializable(source_info)
    if extra:
        metadata["extra"] = _to_serializable(extra)
    return metadata
134
+
135
+
136
def save_metadata_json(path: Path, metadata: Dict[str, Any]) -> None:
    """Write *metadata* to *path* as indented JSON, creating parent dirs.

    Example:
        >>> save_metadata_json(Path("outputs/scene_metadata.json"), metadata)
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    # Serialize first, then write in one shot.
    path.write_text(json.dumps(metadata, indent=2), encoding="utf-8")
145
+
146
+
147
+ def _normalize_traj(traj: Optional[Tensor], pos: Tensor, dim: int, name: str) -> Tensor:
148
+ if traj is None:
149
+ if pos.ndim != 2 or pos.shape[1] != dim:
150
+ raise ValueError(f"{name} default positions must have shape (N, {dim})")
151
+ return pos.unsqueeze(0)
152
+ if not torch.is_tensor(traj):
153
+ raise TypeError(f"{name} must be a Tensor")
154
+ if traj.ndim == 2:
155
+ if traj.shape[1] != dim:
156
+ raise ValueError(f"{name} must have shape (T, {dim})")
157
+ return traj.unsqueeze(1)
158
+ if traj.ndim == 3:
159
+ if traj.shape[2] != dim:
160
+ raise ValueError(f"{name} must have shape (T, N, {dim})")
161
+ return traj
162
+ raise ValueError(f"{name} must have shape (T, N, {dim})")
163
+
164
+
165
+ def _compute_doa(src_traj: Tensor, mic_traj: Tensor) -> tuple[Tensor, Tensor]:
166
+ vec = src_traj[:, :, None, :] - mic_traj[:, None, :, :]
167
+ x = vec[..., 0]
168
+ y = vec[..., 1]
169
+ azimuth = torch.atan2(y, x)
170
+ if vec.shape[-1] < 3:
171
+ elevation = torch.zeros_like(azimuth)
172
+ else:
173
+ z = vec[..., 2]
174
+ r_xy = torch.sqrt(x**2 + y**2)
175
+ elevation = torch.atan2(z, r_xy)
176
+ return azimuth, elevation
177
+
178
+
179
def _array_attributes(mics: MicrophoneArray) -> ArrayAttributes:
    """Derive a structured array description from a MicrophoneArray.

    Classifies the geometry by mic count, computes the centroid, and
    records the smallest non-zero pairwise distance as the spacing.
    """
    positions = mics.positions
    count = positions.shape[0]
    geometry = {1: "single", 2: "binaural"}.get(count, "custom")
    spacing = None
    if count >= 2:
        pairwise = torch.cdist(positions, positions)
        nonzero = pairwise[pairwise > 0]
        if nonzero.numel() > 0:
            spacing = float(nonzero.min().item())
    return ArrayAttributes(
        geometry_name=geometry,
        positions=positions,
        orientation=mics.orientation,
        center=positions.mean(dim=0),
        normal=None,
        spacing=spacing,
    )
203
+
204
+
205
+ def _to_serializable(value: Any) -> Any:
206
+ if value is None:
207
+ return None
208
+ if torch.is_tensor(value):
209
+ return value.detach().cpu().tolist()
210
+ if isinstance(value, Path):
211
+ return str(value)
212
+ if isinstance(value, dict):
213
+ return {k: _to_serializable(v) for k, v in value.items()}
214
+ if isinstance(value, (list, tuple)):
215
+ return [_to_serializable(v) for v in value]
216
+ return value
torchrir/plotting.py CHANGED
@@ -20,7 +20,15 @@ def plot_scene_static(
20
20
  title: Optional[str] = None,
21
21
  show: bool = False,
22
22
  ):
23
- """Plot a static room with source and mic positions."""
23
+ """Plot a static room with source and mic positions.
24
+
25
+ Example:
26
+ >>> ax = plot_scene_static(
27
+ ... room=[6.0, 4.0, 3.0],
28
+ ... sources=[[1.0, 2.0, 1.5]],
29
+ ... mics=[[2.0, 2.0, 1.5]],
30
+ ... )
31
+ """
24
32
  plt, ax = _setup_axes(ax, room)
25
33
 
26
34
  size = _room_size(room, ax)
@@ -46,21 +54,35 @@ def plot_scene_dynamic(
46
54
  src_traj: Tensor | Sequence,
47
55
  mic_traj: Tensor | Sequence,
48
56
  step: int = 1,
57
+ src_pos: Optional[Tensor | Sequence] = None,
58
+ mic_pos: Optional[Tensor | Sequence] = None,
49
59
  ax: Any | None = None,
50
60
  title: Optional[str] = None,
51
61
  show: bool = False,
52
62
  ):
53
- """Plot source and mic trajectories within a room."""
63
+ """Plot source and mic trajectories within a room.
64
+
65
+ If trajectories are static, only positions are plotted.
66
+
67
+ Example:
68
+ >>> ax = plot_scene_dynamic(
69
+ ... room=[6.0, 4.0, 3.0],
70
+ ... src_traj=src_traj,
71
+ ... mic_traj=mic_traj,
72
+ ... )
73
+ """
54
74
  plt, ax = _setup_axes(ax, room)
55
75
 
56
76
  size = _room_size(room, ax)
57
77
  _draw_room(ax, size)
58
78
 
59
- src_traj = _as_trajectory(src_traj, ax)
60
- mic_traj = _as_trajectory(mic_traj, ax)
79
+ src_traj = _as_trajectory(src_traj)
80
+ mic_traj = _as_trajectory(mic_traj)
81
+ src_pos_t = _extract_positions(src_pos, ax) if src_pos is not None else src_traj[0]
82
+ mic_pos_t = _extract_positions(mic_pos, ax) if mic_pos is not None else mic_traj[0]
61
83
 
62
- _plot_trajectories(ax, src_traj, step=step, label="source path")
63
- _plot_trajectories(ax, mic_traj, step=step, label="mic path")
84
+ _plot_entity(ax, src_traj, src_pos_t, step=step, label="sources", marker="^")
85
+ _plot_entity(ax, mic_traj, mic_pos_t, step=step, label="mics", marker="o")
64
86
 
65
87
  if title:
66
88
  ax.set_title(title)
@@ -70,7 +92,9 @@ def plot_scene_dynamic(
70
92
  return ax
71
93
 
72
94
 
73
- def _setup_axes(ax: Any | None, room: Room | Sequence[float] | Tensor) -> tuple[Any, Any]:
95
+ def _setup_axes(
96
+ ax: Any | None, room: Room | Sequence[float] | Tensor
97
+ ) -> tuple[Any, Any]:
74
98
  """Create 2D/3D axes based on room dimension."""
75
99
  import matplotlib.pyplot as plt
76
100
 
@@ -109,8 +133,9 @@ def _draw_room_2d(ax: Any, size: Tensor) -> None:
109
133
  """Draw a 2D rectangular room."""
110
134
  import matplotlib.patches as patches
111
135
 
112
- rect = patches.Rectangle((0.0, 0.0), size[0].item(), size[1].item(),
113
- fill=False, edgecolor="black")
136
+ rect = patches.Rectangle(
137
+ (0.0, 0.0), size[0].item(), size[1].item(), fill=False, edgecolor="black"
138
+ )
114
139
  ax.add_patch(rect)
115
140
  ax.set_xlim(0, size[0].item())
116
141
  ax.set_ylim(0, size[1].item())
@@ -164,7 +189,9 @@ def _draw_room_3d(ax: Any, size: Tensor) -> None:
164
189
  ax.set_zlabel("z")
165
190
 
166
191
 
167
- def _extract_positions(entity: Source | MicrophoneArray | Tensor | Sequence, ax: Any | None) -> Tensor:
192
+ def _extract_positions(
193
+ entity: Source | MicrophoneArray | Tensor | Sequence, ax: Any | None
194
+ ) -> Tensor:
168
195
  """Extract positions from Source/MicrophoneArray or raw tensor."""
169
196
  if isinstance(entity, (Source, MicrophoneArray)):
170
197
  pos = entity.positions
@@ -176,18 +203,34 @@ def _extract_positions(entity: Source | MicrophoneArray | Tensor | Sequence, ax:
176
203
  return pos
177
204
 
178
205
 
179
- def _scatter_positions(ax: Any, positions: Tensor, *, label: str, marker: str) -> None:
206
+ def _scatter_positions(
207
+ ax: Any,
208
+ positions: Tensor,
209
+ *,
210
+ label: str,
211
+ marker: str,
212
+ color: Optional[str] = None,
213
+ ) -> None:
180
214
  """Scatter-plot positions in 2D or 3D."""
181
215
  if positions.numel() == 0:
182
216
  return
183
217
  dim = positions.shape[1]
184
218
  if dim == 2:
185
- ax.scatter(positions[:, 0], positions[:, 1], label=label, marker=marker)
219
+ ax.scatter(
220
+ positions[:, 0], positions[:, 1], label=label, marker=marker, color=color
221
+ )
186
222
  else:
187
- ax.scatter(positions[:, 0], positions[:, 1], positions[:, 2], label=label, marker=marker)
223
+ ax.scatter(
224
+ positions[:, 0],
225
+ positions[:, 1],
226
+ positions[:, 2],
227
+ label=label,
228
+ marker=marker,
229
+ color=color,
230
+ )
188
231
 
189
232
 
190
- def _as_trajectory(traj: Tensor | Sequence, ax: Any | None) -> Tensor:
233
+ def _as_trajectory(traj: Tensor | Sequence) -> Tensor:
191
234
  """Validate and normalize a trajectory tensor."""
192
235
  traj = as_tensor(traj)
193
236
  if traj.ndim != 3:
@@ -195,16 +238,73 @@ def _as_trajectory(traj: Tensor | Sequence, ax: Any | None) -> Tensor:
195
238
  return traj
196
239
 
197
240
 
198
- def _plot_trajectories(ax: Any, traj: Tensor, *, step: int, label: str) -> None:
199
- """Plot trajectories for each entity."""
241
def _plot_entity(
    ax: Any,
    traj: Tensor,
    positions: Tensor,
    *,
    step: int,
    label: str,
    marker: str,
) -> None:
    """Plot trajectories and/or static positions with a unified legend entry.

    Args:
        ax: Matplotlib 2D or 3D axes.
        traj: Trajectory tensor; indexed as (T, N, dim) here.
        positions: Static positions of shape (N, dim); replaced by ``traj[0]``
            when the shapes disagree.
        step: Subsampling stride along the time axis.
        label: Legend label shared by the whole group of entities.
        marker: Matplotlib marker used for path starts / static points.
    """
    if traj.numel() == 0:
        return
    import matplotlib.pyplot as plt

    if positions.shape != traj.shape[1:]:
        positions = traj[0]
    moving = _is_moving(traj, positions)
    # Resolve the active color cycle; fall back to the default "C0".."C5" names.
    colors = plt.rcParams.get("axes.prop_cycle", None)
    if colors is not None:
        palette = colors.by_key().get("color", [])
    else:
        palette = []
    if not palette:
        palette = ["C0", "C1", "C2", "C3", "C4", "C5"]

    dim = traj.shape[2]
    for idx in range(traj.shape[1]):
        color = palette[idx % len(palette)]
        # Only the first entity contributes a legend entry for the group.
        lbl = label if idx == 0 else "_nolegend_"
        if moving:
            if dim == 2:
                xy = traj[::step, idx]
                ax.plot(
                    xy[:, 0],
                    xy[:, 1],
                    label=lbl,
                    color=color,
                    marker=marker,
                    markevery=[0],  # mark only the path's starting point
                )
            else:
                xyz = traj[::step, idx]
                ax.plot(
                    xyz[:, 0],
                    xyz[:, 1],
                    xyz[:, 2],
                    label=lbl,
                    color=color,
                    marker=marker,
                    markevery=[0],  # mark only the path's starting point
                )
        # Always scatter the static position in the matching color.
        pos = positions[idx : idx + 1]
        _scatter_positions(ax, pos, label="_nolegend_", marker=marker, color=color)
    if not moving:
        # ensure legend marker uses the group label
        _scatter_positions(
            ax,
            positions[:1],
            label=label,
            marker=marker,
            color=palette[0],
        )
303
+
304
+
305
+ def _is_moving(traj: Tensor, positions: Tensor, *, tol: float = 1e-6) -> bool:
306
+ """Return True if any trajectory deviates from the provided positions."""
307
+ if traj.numel() == 0:
308
+ return False
309
+ pos0 = positions.unsqueeze(0).expand_as(traj)
310
+ return bool(torch.any(torch.linalg.norm(traj - pos0, dim=-1) > tol).item())
@@ -26,6 +26,8 @@ def plot_scene_and_save(
26
26
  ) -> tuple[list[Path], list[Path]]:
27
27
  """Plot static and dynamic scenes and save images to disk.
28
28
 
29
+ Dynamic plots show trajectories for moving entities and points for fixed ones.
30
+
29
31
  Args:
30
32
  out_dir: Output directory for PNGs.
31
33
  room: Room size tensor or sequence.
@@ -41,6 +43,17 @@ def plot_scene_and_save(
41
43
 
42
44
  Returns:
43
45
  Tuple of (static_paths, dynamic_paths).
46
+
47
+ Example:
48
+ >>> plot_scene_and_save(
49
+ ... out_dir=Path("outputs"),
50
+ ... room=[6.0, 4.0, 3.0],
51
+ ... sources=[[1.0, 2.0, 1.5]],
52
+ ... mics=[[2.0, 2.0, 1.5]],
53
+ ... src_traj=src_traj,
54
+ ... mic_traj=mic_traj,
55
+ ... prefix="scene",
56
+ ... )
44
57
  """
45
58
  out_dir = Path(out_dir)
46
59
  out_dir.mkdir(parents=True, exist_ok=True)
@@ -85,11 +98,12 @@ def plot_scene_and_save(
85
98
  room=view_room,
86
99
  src_traj=view_src_traj,
87
100
  mic_traj=view_mic_traj,
101
+ src_pos=view_src,
102
+ mic_pos=view_mic,
88
103
  step=step,
89
104
  title=f"Room scene ({view_dim}D trajectories)",
90
105
  show=False,
91
106
  )
92
- _overlay_positions(ax, view_src, view_mic)
93
107
  dynamic_path = out_dir / f"{prefix}_dynamic_{view_dim}d.png"
94
108
  _save_axes(ax, dynamic_path, show=show)
95
109
  dynamic_paths.append(dynamic_path)
@@ -113,7 +127,10 @@ def _positions_to_cpu(entity: torch.Tensor | object) -> torch.Tensor:
113
127
  return pos
114
128
 
115
129
 
116
- def _traj_steps(src_traj: Optional[torch.Tensor | Sequence], mic_traj: Optional[torch.Tensor | Sequence]) -> int:
130
+ def _traj_steps(
131
+ src_traj: Optional[torch.Tensor | Sequence],
132
+ mic_traj: Optional[torch.Tensor | Sequence],
133
+ ) -> int:
117
134
  """Infer the number of trajectory steps."""
118
135
  if src_traj is not None:
119
136
  return int(_to_cpu(src_traj).shape[0])
@@ -142,32 +159,3 @@ def _save_axes(ax: Any, path: Path, *, show: bool) -> None:
142
159
  if show:
143
160
  plt.show()
144
161
  plt.close(fig)
145
-
146
-
147
- def _overlay_positions(ax: Any, sources: torch.Tensor, mics: torch.Tensor) -> None:
148
- """Overlay static source and mic positions on an axis."""
149
- if sources.numel() > 0:
150
- if sources.shape[1] == 2:
151
- ax.scatter(sources[:, 0], sources[:, 1], marker="^", label="sources", color="tab:green")
152
- else:
153
- ax.scatter(
154
- sources[:, 0],
155
- sources[:, 1],
156
- sources[:, 2],
157
- marker="^",
158
- label="sources",
159
- color="tab:green",
160
- )
161
- if mics.numel() > 0:
162
- if mics.shape[1] == 2:
163
- ax.scatter(mics[:, 0], mics[:, 1], marker="o", label="mics", color="tab:orange")
164
- else:
165
- ax.scatter(
166
- mics[:, 0],
167
- mics[:, 1],
168
- mics[:, 2],
169
- marker="o",
170
- label="mics",
171
- color="tab:orange",
172
- )
173
- ax.legend(loc="best")
torchrir/results.py CHANGED
@@ -13,7 +13,13 @@ from .scene import Scene
13
13
 
14
14
  @dataclass(frozen=True)
15
15
  class RIRResult:
16
- """Container for RIRs with metadata."""
16
+ """Container for RIRs with metadata.
17
+
18
+ Example:
19
+ >>> from torchrir import ISMSimulator
20
+ >>> result = ISMSimulator().simulate(scene, config)
21
+ >>> rirs = result.rirs
22
+ """
17
23
 
18
24
  rirs: Tensor
19
25
  scene: Scene
torchrir/room.py CHANGED
@@ -13,7 +13,11 @@ from .utils import as_tensor, ensure_dim
13
13
 
14
14
  @dataclass(frozen=True)
15
15
  class Room:
16
- """Room geometry and acoustic parameters."""
16
+ """Room geometry and acoustic parameters.
17
+
18
+ Example:
19
+ >>> room = Room.shoebox(size=[6.0, 4.0, 3.0], fs=16000, beta=[0.9] * 6)
20
+ """
17
21
 
18
22
  size: Tensor
19
23
  fs: float
@@ -43,7 +47,11 @@ class Room:
43
47
  device: Optional[torch.device | str] = None,
44
48
  dtype: Optional[torch.dtype] = None,
45
49
  ) -> "Room":
46
- """Create a rectangular (shoebox) room."""
50
+ """Create a rectangular (shoebox) room.
51
+
52
+ Example:
53
+ >>> room = Room.shoebox(size=[6.0, 4.0, 3.0], fs=16000, beta=[0.9] * 6)
54
+ """
47
55
  size_t = as_tensor(size, device=device, dtype=dtype)
48
56
  size_t = ensure_dim(size_t)
49
57
  beta_t = None
@@ -54,7 +62,11 @@ class Room:
54
62
 
55
63
  @dataclass(frozen=True)
56
64
  class Source:
57
- """Source container with positions and optional orientation."""
65
+ """Source container with positions and optional orientation.
66
+
67
+ Example:
68
+ >>> sources = Source.from_positions([[1.0, 2.0, 1.5]])
69
+ """
58
70
 
59
71
  positions: Tensor
60
72
  orientation: Optional[Tensor] = None
@@ -70,20 +82,6 @@ class Source:
70
82
  """Return a new Source with updated fields."""
71
83
  return replace(self, **kwargs)
72
84
 
73
- @classmethod
74
- def positions(
75
- cls,
76
- positions: Sequence[Sequence[float]] | Tensor,
77
- *,
78
- orientation: Optional[Sequence[float] | Tensor] = None,
79
- device: Optional[torch.device | str] = None,
80
- dtype: Optional[torch.dtype] = None,
81
- ) -> "Source":
82
- """Construct a Source from positions."""
83
- return cls.from_positions(
84
- positions, orientation=orientation, device=device, dtype=dtype
85
- )
86
-
87
85
  @classmethod
88
86
  def from_positions(
89
87
  cls,
@@ -103,7 +101,11 @@ class Source:
103
101
 
104
102
  @dataclass(frozen=True)
105
103
  class MicrophoneArray:
106
- """Microphone array container."""
104
+ """Microphone array container.
105
+
106
+ Example:
107
+ >>> mics = MicrophoneArray.from_positions([[2.0, 2.0, 1.5]])
108
+ """
107
109
 
108
110
  positions: Tensor
109
111
  orientation: Optional[Tensor] = None
@@ -119,20 +121,6 @@ class MicrophoneArray:
119
121
  """Return a new MicrophoneArray with updated fields."""
120
122
  return replace(self, **kwargs)
121
123
 
122
- @classmethod
123
- def positions(
124
- cls,
125
- positions: Sequence[Sequence[float]] | Tensor,
126
- *,
127
- orientation: Optional[Sequence[float] | Tensor] = None,
128
- device: Optional[torch.device | str] = None,
129
- dtype: Optional[torch.dtype] = None,
130
- ) -> "MicrophoneArray":
131
- """Construct a MicrophoneArray from positions."""
132
- return cls.from_positions(
133
- positions, orientation=orientation, device=device, dtype=dtype
134
- )
135
-
136
124
  @classmethod
137
125
  def from_positions(
138
126
  cls,
torchrir/scene.py CHANGED
@@ -13,7 +13,12 @@ from .room import MicrophoneArray, Room, Source
13
13
 
14
14
  @dataclass(frozen=True)
15
15
  class Scene:
16
- """Container for room, sources, microphones, and optional trajectories."""
16
+ """Container for room, sources, microphones, and optional trajectories.
17
+
18
+ Example:
19
+ >>> scene = Scene(room=room, sources=sources, mics=mics, src_traj=src_traj, mic_traj=mic_traj)
20
+ >>> scene.validate()
21
+ """
17
22
 
18
23
  room: Room
19
24
  sources: Source