FastSIMUS 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,294 @@
1
+ """Array-API utilities."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from types import EllipsisType
6
+ from typing import Any, Literal, Protocol, Self, cast, runtime_checkable
7
+
8
+ from array_api_compat import array_namespace as xpc_array_namespace
9
+
10
+ from fast_simus.backends.mlx import ensure_compat as _ensure_mlx_compat
11
+
12
+
13
@runtime_checkable
class LinAlg(Protocol):
    """Protocol for the optional ``linalg`` extension of the Array API standard.

    Not every array library ships this extension, so callers should verify
    that it is present before relying on it.
    See: https://data-apis.org/array-api/latest/extensions/linear_algebra_functions.html

    Only the single function declared below is part of this protocol.
    """

    # ``x`` is positional-only, as in the Array API specification, so that
    # implementations are free to name their first parameter differently.
    def vector_norm(self, x: Array, /, *, axis: Any = None, keepdims: bool = False, ord: Any = None) -> Array: ...
22
+
23
+
24
@runtime_checkable
class FFT(Protocol):
    """Structural type for the optional ``fft`` extension of the Array API standard.

    Array libraries are free to omit this extension.
    See: https://data-apis.org/array-api/2023.12/extensions/fourier_transform_functions.html

    Only the FFT functions commonly used in PyMUST are declared:
        - rfft / irfft: forward and inverse FFT for real-valued signals
        - rfftfreq: frequency bins matching ``rfft``
        - fftshift: move the zero-frequency component to the centre
    """

    def rfft(
        self,
        x: Array,
        /,
        *,
        n: int | None = None,
        axis: int = -1,
        norm: str = "backward",
    ) -> Array: ...

    def irfft(
        self,
        x: Array,
        /,
        *,
        n: int | None = None,
        axis: int = -1,
        norm: str = "backward",
    ) -> Array: ...

    def rfftfreq(
        self,
        n: int,
        /,
        *,
        d: float = 1.0,
        device: Any = None,
    ) -> Array: ...

    def fftshift(
        self,
        x: Array,
        /,
        *,
        axes: int | tuple[int, ...] | None = None,
    ) -> Array: ...
41
+
42
+
43
+ @runtime_checkable
44
+ class _ArrayNamespace(Protocol):
45
+ """Protocol for array namespaces that conform to the Array API standard.
46
+
47
+ This covers the common operations and data types used throughout the FastSIMUS codebase.
48
+ Based on the Array API specification: https://data-apis.org/array-api/latest/
49
+
50
+ Notes:
51
+ - This base protocol does NOT include optional extensions (linalg, fft).
52
+ Use the extended protocols for namespaces that have these extensions.
53
+ - We remove some Array API standard functions that are not used in the codebase,
54
+ to be lenient about other array libraries like MLX
55
+ """
56
+
57
+ # Data types
58
+ float32: Any
59
+ float64: Any
60
+ complex64: Any
61
+ int8: Any
62
+ int16: Any
63
+ int32: Any
64
+ int64: Any
65
+ uint8: Any
66
+ uint16: Any
67
+ uint32: Any
68
+ uint64: Any
69
+
70
+ # Constants
71
+ pi: float
72
+
73
+ # Creation functions
74
+ def asarray(self, obj: Any, *, dtype: Any = None, device: Any = None, copy: bool = False) -> Array: ...
75
+ def arange(
76
+ self,
77
+ start: int | float,
78
+ /,
79
+ stop: int | float | None = None,
80
+ step: int | float = 1,
81
+ *,
82
+ dtype: Any = None,
83
+ device: Any = None,
84
+ ) -> Array: ...
85
+ def linspace(
86
+ self,
87
+ start: int | float | complex,
88
+ stop: int | float | complex,
89
+ /,
90
+ num: int,
91
+ *,
92
+ dtype: Any = None,
93
+ device: Any = None,
94
+ endpoint: bool = True,
95
+ ) -> Array: ...
96
+ def meshgrid(self, *arrays: Array, indexing: Literal["xy", "ij"] = "xy") -> list[Array]: ...
97
+ def ones(self, shape: int | tuple[int, ...], *, dtype: Any = None, device: Any = None) -> Array: ...
98
+ def ones_like(self, x: Array, /, *, dtype: Any = None, device: Any = None) -> Array: ...
99
+ def zeros(self, shape: Any, *, dtype: Any = None, device: Any = None) -> Array: ...
100
+ def zeros_like(self, x: Array, /, *, dtype: Any = None, device: Any = None) -> Array: ...
101
+
102
+ # Element-wise functions
103
+ def abs(self, x: Array, /) -> Array: ...
104
+ def asin(self, x: Array, /) -> Array: ...
105
+ def atan2(self, x1: Array, x2: Array, /) -> Array: ...
106
+ def conj(self, x: Array, /) -> Array: ...
107
+ def cos(self, x: Array) -> Array: ...
108
+ def exp(self, x: Array, /) -> Array: ...
109
+ def floor(self, x: Array, /) -> Array: ...
110
+ def isfinite(self, x: Array, /) -> Array: ...
111
+ def isnan(self, x: Array, /) -> Array: ...
112
+ def log(self, x: Array, /) -> Array: ...
113
+ def log10(self, x: Array, /) -> Array: ...
114
+ def max(self, x: Array, /, *, axis: int | tuple[int, ...] | None = None, keepdims: bool = False) -> Array: ...
115
+ def mean(self, x: Array, /, *, axis: int | tuple[int, ...] | None = None, keepdims: bool = False) -> Array: ...
116
+ def min(self, x: Array, /, *, axis: int | tuple[int, ...] | None = None, keepdims: bool = False) -> Array: ...
117
+ def real(self, x: Array, /) -> Array: ...
118
+ def sign(self, x: Array) -> Array: ...
119
+ def sin(self, x: Array) -> Array: ...
120
+ def sqrt(self, x: Array) -> Array: ...
121
+ def sum(self, x: Array, *, axis: Any = None, keepdims: bool = False) -> Array: ...
122
+
123
+ # Indexing functions
124
+ def take(self, x: Array, indices: Array, /, *, axis: int | None = None) -> Array: ...
125
+
126
+ # Linear algebra functions (part of main API)
127
+ def matmul(self, x1: Array, x2: Array) -> Array: ...
128
+ def tensordot(self, x1: Array, x2: Array, *, axes: Any = 2) -> Array: ...
129
+
130
+ # Manipulation functions
131
+ def broadcast_to(self, x: Array, /, shape: tuple[int, ...]) -> Array: ...
132
+ def reshape(self, x: Array, /, shape: tuple[int, ...], *, copy: bool | None = None) -> Array: ...
133
+ def stack(self, arrays: Any, *, axis: int = 0) -> Array: ...
134
+
135
+ # Searching functions
136
+ def argmax(self, x: Array, /, *, axis: int | None = None, keepdims: bool = False) -> Array: ...
137
+ def where(self, condition: Array, x1: Array, x2: Array, /) -> Array: ...
138
+
139
+ # Utility functions
140
+ def all(self, x: Array, /, *, axis: int | tuple[int, ...] | None = None, keepdims: bool = False) -> Array: ...
141
+
142
+
143
@runtime_checkable
class _ArrayNamespaceWithLinAlg(_ArrayNamespace, Protocol):
    """``_ArrayNamespace`` plus the optional linear algebra extension.

    Annotate with this type only when the namespace is known to expose
    ``linalg``. Most code should accept the base ``_ArrayNamespace`` and test
    for the extension with ``hasattr(xp, 'linalg')``.
    """

    # Sub-namespace holding the linear algebra extension functions.
    linalg: LinAlg
152
+
153
+
154
@runtime_checkable
class _ArrayNamespaceWithFFT(_ArrayNamespace, Protocol):
    """``_ArrayNamespace`` plus the optional FFT extension.

    Annotate with this type only when the namespace is known to expose
    ``fft``. Most code should accept the base ``_ArrayNamespace`` and test for
    the extension with ``hasattr(xp, 'fft')``.
    """

    # Sub-namespace holding the FFT extension functions.
    fft: FFT
163
+
164
+
165
@runtime_checkable
class _ArrayNamespaceWithLinAlgAndFFT(_ArrayNamespace, Protocol):
    """``_ArrayNamespace`` plus both optional extensions (linalg and fft).

    Annotate with this type only when the namespace is known to expose both
    extensions. Most code should accept the base ``_ArrayNamespace`` and test
    with ``hasattr(xp, 'linalg')`` and ``hasattr(xp, 'fft')`` as needed.
    """

    # Sub-namespace holding the linear algebra extension functions.
    linalg: LinAlg
    # Sub-namespace holding the FFT extension functions.
    fft: FFT
176
+
177
+
178
+ @runtime_checkable
179
+ class Array(Protocol):
180
+ """Protocol for arrays that conform to the Array API standard.
181
+
182
+ This is a lightweight implementation that covers the basic operations
183
+ needed by the FastSIMUS codebase. It will eventually be replaced by one of the
184
+ following:
185
+ https://github.com/magnusdk/spekk/commit/d17d5bbd3e2beac97142a9397ce25942b787a7ed
186
+ https://github.com/data-apis/array-api/pull/589/
187
+ https://github.com/data-apis/array-api-typing
188
+
189
+ Note:
190
+ https://data-apis.org/array-api/latest/API_specification/index.html
191
+ """
192
+
193
+ dtype: Any
194
+ ndim: int
195
+ shape: tuple[int, ...]
196
+ size: int
197
+
198
+ # Basic operations used in the codebase
199
+ def __add__(self, other: int | float | complex | Self) -> Self: ...
200
+ def __sub__(self, other: int | float | complex | Self) -> Self: ...
201
+ def __mul__(self, other: int | float | complex | Self) -> Self: ...
202
+ def __truediv__(self, other: int | float | complex | Self) -> Self: ...
203
+ def __pow__(self, other: int | float | complex | Self) -> Self: ...
204
+ def __matmul__(self, other: Self) -> Self: ...
205
+ def __and__(self, other: Self) -> Self: ...
206
+ def __or__(self, other: Self) -> Self: ...
207
+ def __neg__(self) -> Self: ...
208
+ def __lt__(self, other: int | float | complex | Self) -> Self: ...
209
+ def __gt__(self, other: int | float | complex | Self) -> Self: ...
210
+ def __le__(self, other: int | float | complex | Self) -> Self: ...
211
+ def __ge__(self, other: int | float | complex | Self) -> Self: ...
212
+ def __eq__(self, other: int | float | complex | Self) -> Self: ... # type: ignore[invalid-method-override]
213
+ def __ne__(self, other: int | float | complex | Self) -> Self: ... # type: ignore[invalid-method-override]
214
+ def __getitem__(
215
+ self, key: int | slice | EllipsisType | None | tuple[int | slice | EllipsisType | Self | None, ...] | Self, /
216
+ ) -> Self: ...
217
+
218
+ # Reflected operations for scalar op array
219
+ def __radd__(self, other: int | float | complex | Self) -> Self: ...
220
+ def __rsub__(self, other: int | float | complex | Self) -> Self: ...
221
+ def __rmul__(self, other: int | float | complex | Self) -> Self: ...
222
+ def __rtruediv__(self, other: int | float | complex | Self) -> Self: ...
223
+ def __rpow__(self, other: int | float | complex | Self) -> Self: ...
224
+
225
+ # Only defined for zero-dimensional arrays
226
+ def __float__(self) -> float: ...
227
+ def __int__(self) -> int: ...
228
+ def __bool__(self) -> bool: ...
229
+
230
+
231
# Union accepted wherever an operand may be either an array or a plain Python scalar.
ArrayOrScalar = Array | int | float | complex | bool
232
+
233
+
234
def is_mlx_namespace(xp: object) -> bool:
    """Return True if xp is an MLX namespace (mlx.core or compatible wrapper).

    MLX is not yet supported by array-api-compat, so we use a name-based check.
    This mirrors the pattern used by array_api_compat.is_jax_namespace() etc.

    Args:
        xp: Candidate namespace object; any object is accepted.

    Returns:
        True when ``xp.__name__`` is a string starting with ``"mlx"``. False
        otherwise, including when ``xp`` has no ``__name__`` or when its
        ``__name__`` is not a string.
    """
    name = getattr(xp, "__name__", "")
    # Guard against a non-string ``__name__`` (e.g. ``None``): such an object
    # cannot name an MLX module, and calling ``startswith`` on it would raise.
    # NOTE(review): the prefix test also accepts unrelated modules whose name
    # merely begins with "mlx" - confirm that is acceptable for callers.
    return isinstance(name, str) and name.startswith("mlx")
241
+
242
+
243
def is_cupy_namespace(xp: object) -> bool:
    """Return True if xp is a CuPy namespace (raw cupy or array_api_compat wrapper).

    array_api_compat wraps cupy as ``array_api_compat.cupy`` whose ``__name__``
    contains ``cupy``; raw ``cupy`` matches the same predicate. Mirrors
    ``is_mlx_namespace`` (CuPy *does* have an array_api_compat wrapper, but a
    string check covers both raw and wrapped variants without an import).

    Args:
        xp: Candidate namespace object; any object is accepted.

    Returns:
        True when ``xp.__name__`` is a string containing ``"cupy"``. False
        otherwise, including when ``xp`` has no ``__name__`` or when its
        ``__name__`` is not a string.
    """
    name = getattr(xp, "__name__", "")
    # Require a real string: ``"cupy" in None`` raises TypeError, and a
    # container such as ``["cupy"]`` would otherwise be a false positive.
    return isinstance(name, str) and "cupy" in name
252
+
253
+
254
def array_namespace(
    *arrays: Any,
) -> _ArrayNamespace | _ArrayNamespaceWithLinAlg | _ArrayNamespaceWithFFT | _ArrayNamespaceWithLinAlgAndFFT:
    """Look up the array namespace of ``arrays`` and return it with a static type.

    Thin, typed front-end to ``array_api_compat.array_namespace``: the lookup
    itself is delegated, and the result is cast to the namespace protocols so
    that type checkers know which functions are available.

    Args:
        *arrays: Arrays whose common namespace should be returned.

    Returns:
        The namespace resolved for the arrays (numpy, cupy, jax.numpy, etc.).
        It may or may not provide the optional ``linalg`` and ``fft``
        extensions.

    Note:
        The optional extensions are not available in every array library, so
        check for them with hasattr() before use:
            - hasattr(xp, 'linalg') for the linear algebra extension
            - hasattr(xp, 'fft') for the FFT extension

    Examples:
        >>> xp = array_namespace(arr)
        >>> if hasattr(xp, 'linalg'):
        ...     norm = xp.linalg.vector_norm(arr)
        ... else:
        ...     # Fallback implementation
        ...     norm = xp.sqrt(xp.sum(arr * arr, axis=-1))
        >>>
        >>> if hasattr(xp, 'fft'):
        ...     spectrum = xp.fft.rfft(signal)
        ... else:
        ...     raise RuntimeError("FFT extension not available")
    """
    namespace = xpc_array_namespace(*arrays)

    # MLX namespaces are additionally passed through the project's MLX
    # compatibility hook before being handed back to the caller.
    if is_mlx_namespace(namespace):
        _ensure_mlx_compat(namespace)

    # The cast is a no-op at runtime; it only informs static type checkers.
    return cast(
        _ArrayNamespace | _ArrayNamespaceWithLinAlg | _ArrayNamespaceWithFFT | _ArrayNamespaceWithLinAlgAndFFT,
        namespace,
    )
@@ -0,0 +1,88 @@
1
+ """Geometry calculation utilities."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from math import inf
6
+ from typing import TYPE_CHECKING, Any
7
+
8
+ from beartype import beartype as typechecker
9
+ from jaxtyping import Float, jaxtyped
10
+
11
+ if TYPE_CHECKING:
12
+ from fast_simus.utils._array_api import _ArrayNamespace
13
+
14
# Type alias for Array API objects (until protocol is standardized).
# It is ``Any`` on purpose: in this module it only serves as the array-type
# slot of the jaxtyping shape annotations below.
Array = Any
16
+
17
+
18
@jaxtyped(typechecker=typechecker)
def element_positions(
    n_elements: int,
    pitch: float,
    radius: float,
    xp: _ArrayNamespace,
) -> tuple[
    Float[Array, "n_elements 2"],
    Float[Array, " n_elements"] | None,
    float,
]:
    """Compute the (x, z) position of every transducer element.

    Handles both probe geometries. A linear array (``radius == inf``) places
    its elements evenly along the x-axis, centred on zero, with z = 0. A convex
    array places them on a circular arc of the given radius, evenly spaced in
    angle and symmetric about the z-axis.

    Args:
        n_elements: Number of transducer elements.
        pitch: Element pitch (center-to-center spacing) in meters.
        radius: Curvature radius in meters. Use inf for linear arrays.
        xp: Array namespace used to create the arrays in the desired backend.

    Returns:
        Tuple of (positions, theta_rad, apex_offset_m):
            - positions: (x, z) coordinates in meters, shape (n_elements, 2).
              positions[:, 0] is lateral (x), positions[:, 1] is axial (z).
            - theta_rad: Angular position of each element in radians, shape
              (n_elements,), for convex arrays; None for linear arrays.
            - apex_offset_m: ``sqrt(radius**2 - (chord / 2)**2)``, where chord
              is the straight-line distance between the two outermost
              elements. It is subtracted from ``radius * cos(theta)`` so the
              outermost elements sit at z = 0. Zero for linear arrays.
    """
    if radius == inf:
        # Linear array: indices centred on zero and scaled by the pitch.
        # NOTE(review): x is built as float32 while z uses the backend's
        # default dtype, so the stacked dtype depends on the backend's
        # promotion rules - confirm this is intended.
        element_index = xp.arange(n_elements, dtype=xp.float32)
        lateral = (element_index - (n_elements - 1) / 2) * pitch
        axial = xp.zeros(n_elements)
        return xp.stack([lateral, axial], axis=-1), None, 0.0

    # Convex array. One pitch is the chord between neighbouring elements, so
    # asin(pitch / (2 * radius)) is half the angle between two neighbours.
    half_step_angle = xp.asin(xp.asarray(pitch / 2 / radius))
    # (n_elements - 1) half-steps: angle from the z-axis to an outermost element.
    half_aperture_angle = half_step_angle * (n_elements - 1)
    # Straight-line distance between the first and the last element.
    chord = xp.asarray(2 * radius) * xp.sin(half_aperture_angle)

    # h = sqrt(radius^2 - (chord/2)^2)
    apex_offset_arr = xp.sqrt(xp.asarray(radius**2) - (chord / 2) ** 2)
    apex_offset = float(apex_offset_arr)

    # Angles of the two outermost elements, converted to Python floats so that
    # they can be used as linspace end points.
    first_angle = float(xp.atan2(-chord / 2, apex_offset_arr))
    last_angle = float(xp.atan2(chord / 2, apex_offset_arr))
    theta = xp.linspace(first_angle, last_angle, n_elements)

    # Polar to Cartesian; z is shifted by h so the outermost elements have z = 0.
    curvature_radius = xp.asarray(radius)
    axial = curvature_radius * xp.cos(theta) - apex_offset
    lateral = curvature_radius * xp.sin(theta)

    return xp.stack([lateral, axial], axis=-1), theta, apex_offset