eiko 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eiko/__init__.py +159 -0
- eiko/animate_eikonal.py +201 -0
- eiko/eiko_jax.py +326 -0
- eiko/eiko_torch.py +173 -0
- eiko/src/BatchedFIMSolver.cuh +393 -0
- eiko/src/bindings/jax_bindings.cu +142 -0
- eiko/src/bindings/mex_bindings.cu +219 -0
- eiko/src/bindings/torch_bindings.cu +182 -0
- eiko/src/eiko_dispatch.cuh +88 -0
- eiko/src/eiko_kernels.cuh +1448 -0
- eiko-0.8.0.dist-info/METADATA +217 -0
- eiko-0.8.0.dist-info/RECORD +15 -0
- eiko-0.8.0.dist-info/WHEEL +5 -0
- eiko-0.8.0.dist-info/licenses/LICENSE +28 -0
- eiko-0.8.0.dist-info/top_level.txt +1 -0
eiko/__init__.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
|
|
4
|
+
# ---------------------------------------------------------
|
|
5
|
+
# PATH RESOLUTION
|
|
6
|
+
# ---------------------------------------------------------
|
|
7
|
+
# Resolve paths once for the entire package.
|
|
8
|
+
|
|
9
|
+
# Base directory where this file lives (Eiko/python/eiko/ or .../site-packages/eiko/).
|
|
10
|
+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# Candidate locations of the C++/CUDA sources, checked in this order:
#   1. installed via pip: the sources ship inside the package as eiko/src;
#   2. local development checkout: src/ is a sibling of the python/ directory.
SRC_DIR_INSTALLED = os.path.join(BASE_DIR, 'src')
SRC_DIR_DEV = os.path.abspath(os.path.join(BASE_DIR, '..', '..', 'src'))

# Pick the first candidate that exists on disk; fail loudly if neither does.
SRC_DIR = next(
    (candidate for candidate in (SRC_DIR_INSTALLED, SRC_DIR_DEV)
     if os.path.exists(candidate)),
    None,
)
if SRC_DIR is None:
    raise FileNotFoundError(
        f"Could not find Eiko C++/CUDA source directory. Tried:\n"
        f"1. {SRC_DIR_INSTALLED}\n"
        f"2. {SRC_DIR_DEV}"
    )
|
|
29
|
+
|
|
30
|
+
# ---------------------------------------------------------
|
|
31
|
+
# COMPILER FLAGS
|
|
32
|
+
# ---------------------------------------------------------
|
|
33
|
+
if sys.platform != "win32":
    # Non-Windows host compilers (GCC / Clang style flags).
    CXX_ARGS = ['-std=c++20', '-O3']
    NVCC_ARGS = ['-std=c++20', '--use_fast_math', '-arch=native']
else:
    # Windows: MSVC-style flags for the host compiler.
    CXX_ARGS = ['/std:c++20', '/Zc:preprocessor', '/DNOMINMAX']
    NVCC_ARGS = [
        '-std=c++20',
        '-allow-unsupported-compiler',
        # Each -Xcompiler forwards the following flag to the MSVC host compiler.
        '-Xcompiler', '/Zc:preprocessor',
        '-Xcompiler', '/std:c++20',
        '-D_ALLOW_COMPILER_AND_STL_VERSION_MISMATCH',
        '-DNOMINMAX',
        '--use_fast_math',
        '-arch=native',
    ]
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
EXTRA_INCLUDE_PATHS = [SRC_DIR]

# If the CUDA toolkit found by torch has an include/cccl directory, add it and a
# few of its sub-directories to the include search path. torch is used here only
# to locate CUDA_HOME.
try:
    from torch.utils.cpp_extension import CUDA_HOME
except ImportError:
    # Let individual wrappers handle the missing torch error cleanly
    pass
else:
    if CUDA_HOME:
        cccl_base = os.path.join(CUDA_HOME, 'include', 'cccl')
        if os.path.exists(cccl_base):
            EXTRA_INCLUDE_PATHS += [cccl_base] + [
                os.path.join(cccl_base, *sub_path)
                for sub_path in (('thrust',), ('libcudacxx', 'include'), ('cub',))
            ]
|
|
66
|
+
|
|
67
|
+
# ---------------------------------------------------------
|
|
68
|
+
# PUBLIC API EXPORTS
|
|
69
|
+
# ---------------------------------------------------------
|
|
70
|
+
|
|
71
|
+
def _is_jax_array(obj):
    """Return True when the type of *obj* is defined in a module named ``jax...``.

    Only the defining module's name is inspected, so JAX never has to be
    imported just to decide which backend to use. Note that the prefix test
    also matches ``jaxlib`` modules.
    """
    owner = type(obj)
    return owner.__module__.startswith('jax')
|
|
74
|
+
|
|
75
|
+
def eiko2d(u_init, f, v_init=None, dx=1.0, msfm=False, gated=False):
    """
    EIKO2D Computes the shortest time-of-flight in an arbitrary 2D medium.

    Solves for the arrival time u given a slowness map (f = 1/c) and initial
    conditions u_init, in which unknown points are initialized as infinity.
    If u_init is a JAX array the JAX backend is used; every other input goes
    to the PyTorch backend. Each backend module is imported only when needed.

    EXAMPLE USAGES:
        u = eiko2d(u_init, f)                                  # (Standard usage).
        u = eiko2d(u_init, f, dx=0.5, msfm=True)               # (Named arguments).
        u, v_out = eiko2d(u_init, f, v_init=advection_field)   # (Advection).

    REQUIRED INPUTS:
        u_init - Initial conditions (known arrival times/delays).
                 Shape: (H, W) for a single image, or (B, H, W) for a batch.
        f      - Slowness. Can be (H, W) [broadcast to batch] or (B, H, W).

    OPTIONAL INPUTS:
        dx     - Input grid spacing. Default: 1.0.
        v_init - The initial advection field. Same size as u. Default: None (not used).
        msfm   - Whether to enable Multi-Stencil Fast Marching (MSFM).
                 Reduces bias along diagonal directions. Default: False.
        gated  - Whether to enforce positive propagation along the first
                 data dimension. Speeds up computations. It is valid when
                 time only increases when moving axially. Default: False.

    OUTPUTS:
        u - Computed arrival time (time-of-flight) map. Shape matches u_init.
        v - Output advection vectors (returned only if v_init was supplied).
    """
    if not _is_jax_array(u_init):
        from .eiko_torch import eiko2d as backend_solver
    else:
        from .eiko_jax import eiko2d as backend_solver
    return backend_solver(u_init, f, v_init, dx, msfm, gated)
|
|
111
|
+
|
|
112
|
+
def eiko3d(u_init, f, v_init=None, dx=1.0, msfm=False, gated=False):
    """
    EIKO3D Computes the shortest time-of-flight in an arbitrary 3D medium.

    Solves for the arrival time u given a slowness map (f = 1/c) and initial
    conditions u_init, in which unknown points are initialized as infinity.
    If u_init is a JAX array the JAX backend is used; every other input goes
    to the PyTorch backend. Each backend module is imported only when needed.

    EXAMPLE USAGES:
        u = eiko3d(u_init, f)                                  # (Standard usage).
        u = eiko3d(u_init, f, dx=0.5, msfm=True)               # (Named arguments).
        u, v_out = eiko3d(u_init, f, v_init=advection_field)   # (Advection).

    REQUIRED INPUTS:
        u_init - Initial conditions (known arrival times/delays).
                 Shape: (D, H, W) for a single volume, or (B, D, H, W) for a batch.
        f      - Slowness. Can be (D, H, W) [broadcast to batch] or (B, D, H, W).

    OPTIONAL INPUTS:
        dx     - Input grid spacing. Default: 1.0.
        v_init - The initial advection field. Same size as u. Default: None (not used).
        msfm   - Whether to enable Multi-Stencil Fast Marching (MSFM).
                 Reduces bias along diagonal directions. Default: False.
        gated  - Whether to enforce positive propagation along the first
                 data dimension. Speeds up computations. It is valid when
                 time only increases when moving axially. Default: False.

    OUTPUTS:
        u - Computed arrival time (time-of-flight) map. Shape matches u_init.
        v - Output advection vectors (returned only if v_init was supplied).
    """
    if not _is_jax_array(u_init):
        from .eiko_torch import eiko3d as backend_solver
    else:
        from .eiko_jax import eiko3d as backend_solver
    return backend_solver(u_init, f, v_init, dx, msfm, gated)
|
|
148
|
+
|
|
149
|
+
# Short alias: `eiko` is the same function object as the 2D solver `eiko2d`.
eiko = eiko2d
|
|
150
|
+
|
|
151
|
+
from .animate_eikonal import animate_eikonal
|
|
152
|
+
|
|
153
|
+
# Define what is imported when a user runs `from eiko import *`.
|
|
154
|
+
# `eiko2d` is the primary 2D entry point (`eiko` is only its alias), so it is
# exported as well.
__all__ = ['eiko', 'eiko2d', 'eiko3d', 'animate_eikonal']

# Define the version number, so it is easily accessible.
from importlib.metadata import PackageNotFoundError, version

try:
    __version__ = version("eiko")
except PackageNotFoundError:
    # Running straight from a source checkout (the development layout handled by
    # SRC_DIR_DEV) without `pip install`: no distribution metadata exists, which
    # must not make `import eiko` fail.
    __version__ = "0+unknown"
|
|
159
|
+
|
eiko/animate_eikonal.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
|
|
3
|
+
def animate_eikonal(u, v=1.0, color_map='gray', video_filename='',
                    title='EIKONAL WAVEFRONT', pulse_width=80.0,
                    speed=0.5, overlay=None, outline=None,
                    style='real', render_mode='slice'):
    """
    ANIMATE_EIKONAL Visualizes Eikonal equation travel times as an animated wave.

    A Hann-windowed complex wave packet, ``pulse_width`` time units wide, is
    swept through the travel-time field ``u``. At animation time ``t``, every
    sample whose travel time lies within ``pulse_width / 2`` of ``t`` lights up.

    Args:
        u: Travel-time field, 2D (rows, cols) or 3D (Nz, Ny, Nx). Objects with
            a ``.cpu()`` method (e.g. PyTorch tensors) are detached and copied
            to the host; anything else is converted with ``np.asarray``.
            Non-finite samples (Inf for unreached points, NaN) never light up.
        v: Scalar or array multiplied into every frame (e.g. an amplitude map).
        color_map: Matplotlib colormap name.
        video_filename: If non-empty, the animation is written to this file
            through ffmpeg instead of being shown interactively.
        title: Figure title.
        pulse_width: Width of the wave packet, in the units of ``u``.
        speed: Animation time advanced per frame; must be positive.
        overlay: Optional 2D mask whose 0.5-level contours are filled in red
            (requires scikit-image). Ignored for 3D input.
        outline: Optional 2D mask whose 0.5-level contour is drawn in black.
            Ignored for 3D input.
        style: 'real', 'abs' or 'db' (case-insensitive): how the complex packet
            is mapped to displayed intensities.
        render_mode: 3D only. 'slice' draws three orthogonal slices;
            'isosurface' draws a marching-cubes surface (requires scikit-image).

    Returns:
        The ``FuncAnimation`` object, or ``None`` when ``u`` has no finite sample.

    Raises:
        ImportError: If numpy or matplotlib is not installed.
        ValueError: If ``style`` is unknown or ``speed`` is not positive.
    """

    # Required dependencies to run the visualization.
    try:
        import numpy as np
        import matplotlib.pyplot as plt
        from matplotlib.animation import FFMpegWriter, FuncAnimation
    except ImportError as err:
        raise ImportError(
            "The animation utility requires 'numpy' and 'matplotlib'. "
            "Please install it using: pip install \"eiko[examples]\""
        ) from err

    # Optional dependency for filled overlays and 3D isosurfaces
    # (the outline contour is drawn with plain matplotlib).
    try:
        from skimage import measure
        HAS_SKIMAGE = True
    except ImportError:
        HAS_SKIMAGE = False
        print("[Eiko] Warning: 'scikit-image' not found. Overlays and 3D isosurfaces will be disabled.", file=sys.stderr)

    # Validate options up front. An unknown style used to surface much later
    # as a NameError on `clim`.
    style_key = style.lower()
    if style_key not in ('real', 'abs', 'db'):
        raise ValueError(f"Unknown style {style!r}; expected 'real', 'abs' or 'db'.")
    if speed <= 0:
        raise ValueError(f"speed must be positive, got {speed!r}.")

    # --- Data Type & Device Checks (e.g., PyTorch -> NumPy) ---
    def to_host(x):
        # Anything exposing .cpu() (e.g. a torch tensor) is detached and copied to the host.
        return x.detach().cpu().numpy() if hasattr(x, 'cpu') else x

    u = np.asarray(to_host(u))
    v = to_host(v)
    if overlay is not None:
        overlay = np.asarray(to_host(overlay))
    if outline is not None:
        outline = np.asarray(to_host(outline))

    is_3d = u.ndim == 3

    # --- Safe Data Initialization ---
    finite = np.isfinite(u)
    if not finite.any():
        print("Warning: Travel time field u is completely Inf. Nothing to animate.", file=sys.stderr)
        return

    maxv = np.max(u[finite])
    # astype() returns a copy, so the caller's array is never modified.
    u_safe = u.astype(np.float32)
    # Park every non-finite sample one pulse past the last arrival: it never
    # lights up, and NaN can no longer poison np.min()/np.arange() below.
    u_safe[~finite] = maxv + pulse_width

    if is_3d:
        # In NumPy, typical 3D ordering is [Z, Y, X]; slices index into these axes.
        Nz, Ny, Nx = u.shape
    else:
        Ny, Nx = u.shape
    v_anim = v

    # --- Figure & Axis Setup ---
    fig = plt.figure(figsize=(12, 8), facecolor='white')

    clim = {'real': [-1, 1], 'abs': [0, 1], 'db': [-30, 0]}[style_key]
    db_ref = 20 * np.log10(1.0 + 1e-12)  # +1e-12 prevents log10(0)

    # Setup Render Objects
    if is_3d:
        ax = fig.add_subplot(111, projection='3d')
        ax.set_title(title, fontsize=16, fontweight='bold')
        ax.set_xlabel('X (Lateral)')
        ax.set_ylabel('Y (Elevation)')
        ax.set_zlabel('Z (Depth)')
        ax.invert_zaxis()  # Ultrasound convention: Depth increases downwards

        # The 3D drawing happens inside update(); only prepare indices/levels here.
        sz, sy, sx = Nz // 2, Ny // 2, Nx // 2
        iso_val = -6 if style_key == 'db' else 0.5
    else:
        ax = fig.add_subplot(111)
        ax.set_title(title, fontsize=16, fontweight='bold', fontname='serif')
        ax.set_xlabel('X (Lateral)', fontsize=14, fontname='serif')
        ax.set_ylabel('Z (Depth)', fontsize=14, fontname='serif')

        # imshow places (0, 0) at the top-left, which matches the ultrasound
        # convention (depth going down).
        im = ax.imshow(np.zeros((Ny, Nx)), vmin=clim[0], vmax=clim[1], cmap=color_map, aspect='equal')

        if outline is not None:
            ax.contour(outline, levels=[0.5], colors='k', linewidths=2)

        if overlay is not None and HAS_SKIMAGE:
            for contour in measure.find_contours(overlay, 0.5):
                # contour is (row, col) -> (Y, X)
                ax.fill(contour[:, 1], contour[:, 0], color='red', alpha=0.3, edgecolor='r', linewidth=2.5)

    freq = 6 * np.pi

    # --- Animation Update Logic ---
    def compute_frame(t):
        """Calculates the analytical wave packet for a given time step."""
        diff_t = u_safe - t
        valid_mask = np.abs(diff_t) <= (pulse_width / 2.0)

        img = np.zeros_like(u_safe, dtype=np.complex64)
        if np.any(valid_mask):
            d_valid = diff_t[valid_mask]
            # Analytical Hanning envelope
            envelope = 0.5 * (1.0 + np.cos(2 * np.pi * d_valid / pulse_width))
            # Apply envelope and carrier phase
            img[valid_mask] = envelope * np.exp(1j * freq * d_valid / pulse_width)

        img = img * v_anim

        # Apply Style
        if style_key == 'real':
            return np.real(img)
        if style_key == 'abs':
            return np.abs(img)
        return 20 * np.log10(np.abs(img) + 1e-12) - db_ref

    start_time = np.min(u_safe)
    time_steps = np.arange(start_time, maxv + pulse_width, speed)

    def update(frame_idx):
        final_img = compute_frame(time_steps[frame_idx])

        if not is_3d:
            im.set_data(final_img)
            return [im]

        # Clear all collections (slices or isosurfaces) from the previous frame.
        # Iterate over a snapshot: removing artists while iterating the live
        # ax.collections would skip every other one.
        for coll in list(ax.collections):
            coll.remove()

        # Define levels to force contouring (20 steps is usually enough for smooth waves)
        levels = np.linspace(clim[0], clim[1], 20)
        if render_mode == 'slice':
            X, Y = np.meshgrid(np.arange(Nx), np.arange(Ny))
            ax.contourf(X, Y, final_img[sz, :, :], zdir='z', offset=sz, cmap=color_map, vmin=clim[0], vmax=clim[1], levels=levels)

            Y, Z = np.meshgrid(np.arange(Ny), np.arange(Nz))
            ax.contourf(final_img[:, :, sx], Y, Z, zdir='x', offset=sx, cmap=color_map, vmin=clim[0], vmax=clim[1], levels=levels)

            X, Z = np.meshgrid(np.arange(Nx), np.arange(Nz))
            ax.contourf(X, final_img[:, sy, :], Z, zdir='y', offset=sy, cmap=color_map, vmin=clim[0], vmax=clim[1], levels=levels)

        elif render_mode == 'isosurface' and HAS_SKIMAGE:
            try:
                # Use marching cubes to find the isosurface.
                verts, faces, _, _ = measure.marching_cubes(final_img, level=iso_val)
                from mpl_toolkits.mplot3d.art3d import Poly3DCollection
                mesh = Poly3DCollection(verts[faces], alpha=0.7)
                mesh.set_facecolor([0.2, 0.6, 1.0])
                mesh.set_edgecolor('none')
                ax.add_collection3d(mesh)
            except ValueError:
                pass  # Fails cleanly if the volume doesn't contain the isovalue yet

        return ax.collections

    # --- Run / Export Animation ---
    # Blitting speeds up 2D rendering by only redrawing changed artists;
    # it is not well-supported for matplotlib 3D axes.
    anim = FuncAnimation(fig, update, frames=len(time_steps),
                         interval=1000 / 60, blit=not is_3d)

    if video_filename:
        print(f"Exporting video to {video_filename} (this may take a moment)...")
        # Requires the ffmpeg executable on PATH (e.g. `sudo apt install ffmpeg` on Ubuntu).
        writer = FFMpegWriter(fps=60, bitrate=2000)
        anim.save(video_filename, writer=writer)
        print("Video export complete.")
    else:
        plt.show()

    return anim
|
eiko/eiko_jax.py
ADDED
|
@@ -0,0 +1,326 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
import struct
|
|
4
|
+
from functools import partial
|
|
5
|
+
from torch.utils.cpp_extension import load, _get_build_directory
|
|
6
|
+
|
|
7
|
+
try:
|
|
8
|
+
import jax
|
|
9
|
+
import pybind11
|
|
10
|
+
import jaxlib
|
|
11
|
+
except ImportError as e:
|
|
12
|
+
raise ImportError(
|
|
13
|
+
f"\n[Eiko] JAX bindings require 'jax', 'jaxlib', and 'pybind11' to be installed.\n"
|
|
14
|
+
f" Please install the required environment via: pip install \"eiko[jax]\"\n"
|
|
15
|
+
) from e
|
|
16
|
+
|
|
17
|
+
import jax.numpy as jnp
|
|
18
|
+
from jax import core
|
|
19
|
+
from jax.interpreters import batching
|
|
20
|
+
from jax.interpreters import xla
|
|
21
|
+
from jax.interpreters import mlir
|
|
22
|
+
|
|
23
|
+
# Robust Primitive resolution for modern JAX.
|
|
24
|
+
try:
    getattr(core, "Primitive")
except AttributeError:
    # This JAX release no longer exposes Primitive on jax.core: borrow it from
    # the private jax._src.core module and re-attach it so the rest of this
    # file can keep using `core.Primitive`.
    from jax._src.core import Primitive
    core.Primitive = Primitive
|
|
27
|
+
|
|
28
|
+
# Version-agnostic xla_client import.
|
|
29
|
+
try:
|
|
30
|
+
from jaxlib import xla_client
|
|
31
|
+
except ImportError:
|
|
32
|
+
from jax.lib import xla_client
|
|
33
|
+
|
|
34
|
+
# Version-agnostic custom_call import.
|
|
35
|
+
#try:
|
|
36
|
+
# from jaxlib.hlo_helpers import custom_call
|
|
37
|
+
#except ImportError:
|
|
38
|
+
# from jax.interpreters.mlir import custom_call
|
|
39
|
+
|
|
40
|
+
from eiko import SRC_DIR, CXX_ARGS, NVCC_ARGS, EXTRA_INCLUDE_PATHS
|
|
41
|
+
|
|
42
|
+
# ---------------------------------------------------------
|
|
43
|
+
# JIT COMPILATION & LOADING
|
|
44
|
+
# ---------------------------------------------------------
|
|
45
|
+
# `_get_build_directory` is a private helper of torch.utils.cpp_extension; presumably
# it returns the per-extension JIT cache directory (TODO confirm across torch versions).
build_dir = _get_build_directory('eiko_jax_impl', verbose=False)
# "Cached" here only means the directory exists and is non-empty. This flag gates
# nothing but the informational message below: `load()` is called unconditionally.
is_cached = os.path.exists(build_dir) and len(os.listdir(build_dir)) > 0

if not is_cached:
    print("[Eiko] First-time JAX initialization: JIT Compiling CUDA kernels for your GPU... (This may take a minute)")
    sys.stdout.flush()

# The JAX binding translation unit plus pybind11 headers (needed by jax_bindings.cu).
jax_source = os.path.join(SRC_DIR, 'bindings', 'jax_bindings.cu')
jax_includes = EXTRA_INCLUDE_PATHS + [pybind11.get_include()]

# Compile (or reuse) the extension with torch's JIT extension builder and import it.
_fim_jax_impl = load(
    name="eiko_jax_impl",
    sources=[jax_source],
    extra_cflags=CXX_ARGS,
    extra_cuda_cflags=NVCC_ARGS,
    extra_include_paths=jax_includes,
    verbose=False
)

# Register every custom-call target the extension exports with XLA's GPU platform.
# NOTE(review): `registrations()` is defined in jax_bindings.cu (not visible here);
# presumably it maps target names (e.g. "jax_fim_solve") to function capsules.
for name, target in _fim_jax_impl.registrations().items():
    xla_client.register_custom_call_target(name, target, platform="gpu")
|
|
66
|
+
|
|
67
|
+
# =========================================================
|
|
68
|
+
# 1. JAX PRIMITIVE DEFINITION & MLIR LOWERING
|
|
69
|
+
# =========================================================
|
|
70
|
+
_fim_prim = core.Primitive("jax_fim_solve")
_fim_prim.multiple_results = False

# Eager (outside jit) evaluation: JAX's generic apply_primitive compiles and runs
# the primitive on its own. `jax.interpreters.xla.apply_primitive` is a legacy
# re-export that may be missing in newer JAX releases, so fall back to the
# function's home in jax._src.dispatch instead of failing at import time.
try:
    _apply_primitive = xla.apply_primitive
except AttributeError:
    from jax._src.dispatch import apply_primitive as _apply_primitive

_fim_prim.def_impl(partial(_apply_primitive, _fim_prim))
|
|
73
|
+
|
|
74
|
+
def _fim_abstract_eval(*operands, opaque_data, out_shape, out_dtype):
    """Shape/dtype rule for the FIM primitive.

    The output abstract value depends only on the static ``out_shape`` and
    ``out_dtype`` parameters. The operand avals and the opaque solver
    descriptor are ignored.
    """
    result_aval = core.ShapedArray(out_shape, out_dtype)
    return result_aval


_fim_prim.def_abstract_eval(_fim_abstract_eval)
|
|
78
|
+
|
|
79
|
+
def _build_custom_call_agnostic(call_target_name, result_types, operands,
                                operand_layouts, result_layouts, backend_config):
    """
    Emit an XLA custom-call op, working across several JAX/jaxlib versions.

    Tries, in order: the `custom_call` helper from `jaxlib.hlo_helpers`, the
    `custom_call` helper from `jax.interpreters.mlir`, and finally a hand-built
    `CustomCallOp` from the MLIR dialect bindings. The first one available wins.

    Args:
        call_target_name: Name of the registered custom-call target.
        result_types: MLIR result types of the op.
        operands: MLIR operand values.
        operand_layouts: Per-operand minor-to-major dimension orders (or None).
        result_layouts: Per-result minor-to-major dimension orders (or None).
        backend_config: Opaque configuration forwarded to the target (or None).

    Returns:
        Whatever the selected builder returns. The caller (`_fim_lowering`)
        accesses `.results` on it, so this is assumed to be an op-like object.
        NOTE(review): confirm for the older helper variants.
    """
    # 1. Helper from jaxlib.hlo_helpers. The original author dates this to
    #    roughly JAX 0.4.15–0.4.30 (unverified). Only a failing import falls
    #    through; errors raised by the call itself propagate.
    try:
        from jaxlib.hlo_helpers import custom_call
        return custom_call(
            call_target_name,
            result_types=result_types,
            operands=operands,
            operand_layouts=operand_layouts,
            result_layouts=result_layouts,
            backend_config=backend_config
        )
    except ImportError:
        pass

    # 2. Helper from jax.interpreters.mlir (version range unverified).
    #    NOTE(review): AttributeError is swallowed for the import AND for the
    #    call, so an AttributeError raised inside `custom_call` also falls
    #    through to the manual path below.
    try:
        from jax.interpreters.mlir import custom_call
        return custom_call(
            call_target_name,
            result_types=result_types,
            operands=operands,
            operand_layouts=operand_layouts,
            result_layouts=result_layouts,
            backend_config=backend_config
        )
    except (ImportError, AttributeError):
        pass

    # 3. No helper available: build the CustomCallOp directly from the MLIR
    #    bindings. The `mhlo` dialect is preferred when importable, otherwise
    #    the `hlo` dialect module is used.
    import jaxlib.mlir.ir as ir
    try:
        from jaxlib.mlir.dialects import mhlo as hlo
    except ImportError:
        from jaxlib.mlir.dialects import hlo

    def _layout_attr(layouts):
        # Convert a list of dimension-order tuples into an ArrayAttr of
        # index-typed 1D dense integer attributes (one per operand/result).
        if layouts is None:
            return None
        import numpy as np
        attr_list = []
        for layout in layouts:
            arr = np.array(layout, dtype=np.int64)

            # Define the element type as MLIR's 'index' type.
            index_type = ir.IndexType.get()

            # Define the 1D tensor shape explicitly using the index type.
            tensor_type = ir.RankedTensorType.get(arr.shape, index_type)

            # Create the attribute with the enforced tensor type.
            attr = ir.DenseIntElementsAttr.get(arr, type=tensor_type)
            attr_list.append(attr)

        return ir.ArrayAttr.get(attr_list)

    # Attributes that are always set: target name, no side effects,
    # api_version 2 (as a signless i32), and no called computations.
    kwargs = {
        "call_target_name": ir.StringAttr.get(call_target_name),
        "has_side_effect": ir.BoolAttr.get(False),
        "api_version": ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 2),
        "called_computations": ir.ArrayAttr.get([]),
    }

    # Optional attributes are only attached when provided.
    if backend_config is not None:
        kwargs["backend_config"] = ir.StringAttr.get(backend_config)

    if operand_layouts is not None:
        kwargs["operand_layouts"] = _layout_attr(operand_layouts)

    if result_layouts is not None:
        kwargs["result_layouts"] = _layout_attr(result_layouts)

    return hlo.CustomCallOp(result_types, operands, **kwargs)
|
|
157
|
+
|
|
158
|
+
# MLIR lowering rule: The bridge between JAX's Python graph and XLA C++.
|
|
159
|
+
def _fim_lowering(ctx, *args, opaque_data, out_shape, out_dtype):
    """Lower the FIM primitive to an XLA custom call targeting "jax_fim_solve".

    Args:
        ctx: JAX lowering context (unused).
        *args: MLIR values of the operands, in the order they were bound.
        opaque_data: Packed solver descriptor, forwarded as ``backend_config``.
        out_shape: Shape of the single result.
        out_dtype: NumPy dtype of the single result.

    Returns:
        The result values of the emitted custom-call op.
    """
    import jaxlib.mlir.ir as ir

    # Convert numpy dtype to MLIR IR type.
    result_type = ir.RankedTensorType.get(
        out_shape, mlir.dtype_to_ir_type(out_dtype)
    )

    # Minor-to-major order (n-1, ..., 0), i.e. plain row-major buffers. Each
    # operand type is wrapped in RankedTensorType first: some MLIR binding
    # versions return a generic ir.Type from Value.type, which has no `.rank`.
    operand_layouts = [
        tuple(range(ir.RankedTensorType(arg.type).rank)[::-1]) for arg in args
    ]
    result_layouts = [tuple(range(len(out_shape))[::-1])]

    call = _build_custom_call_agnostic(
        "jax_fim_solve",
        result_types=[result_type],
        operands=args,
        operand_layouts=operand_layouts,
        result_layouts=result_layouts,
        backend_config=opaque_data,
    )
    return call.results


mlir.register_lowering(_fim_prim, _fim_lowering, platform="gpu")
|
|
183
|
+
|
|
184
|
+
# =========================================================
|
|
185
|
+
# BATCHING RULE
|
|
186
|
+
# =========================================================
|
|
187
|
+
def _fim_batch_rule(batched_args, batch_dims, *, opaque_data, out_shape, out_dtype):
    """vmap rule: fold the vmapped axis into the solver's own batch dimension.

    The kernel sees each operand as a contiguous (batch, *grid) buffer. So the
    vmapped axis (size N) is moved to the front of every operand, and the batch
    count in the opaque descriptor is multiplied by N.

    Operands that are not mapped are materialised along the new axis so every
    buffer really holds N * batch entries. The one exception is a slowness map
    that the kernel already broadcasts over the whole batch, which stays shared.

    Returns:
        ``(out, 0)``: the batched result, with the vmapped axis at position 0.
    """
    # Descriptor fields: width, height, depth, batch_size, dx, is_3d,
    # is_backward, msfm, has_v, broadcast_f, gated_x, has_tof.
    fmt = '=iiiifiiiiiii'
    unpacked = list(struct.unpack(fmt, opaque_data))

    # Take the vmapped axis size from an operand that is actually mapped. u_init
    # (operand 0) may itself be unmapped, e.g. when vmapping over f only.
    # JAX only calls a batching rule when at least one operand is mapped.
    axis_size = next(
        arg.shape[d] for arg, d in zip(batched_args, batch_dims) if d is not None
    )

    f_shared = bool(unpacked[9])
    aligned_args = []
    for idx, (arg, d) in enumerate(zip(batched_args, batch_dims)):
        if idx == 1 and f_shared:
            if d is None:
                # One slowness map for every solve: the kernel's broadcast
                # flag already covers all axis_size * batch entries.
                aligned_args.append(arg)
                continue
            # Each vmapped slice has its own (batch-shared) map. Expand it to a
            # full (N, *out_shape) buffer and switch the broadcast flag off;
            # otherwise the kernel would only ever read the first map.
            arg = batching.moveaxis(arg, d, 0)
            missing_axes = len(out_shape) - (arg.ndim - 1)
            arg = arg.reshape((axis_size,) + (1,) * missing_axes + tuple(arg.shape[1:]))
            arg = jnp.broadcast_to(arg, (axis_size,) + tuple(out_shape))
            unpacked[9] = 0
        elif d is None:
            # Unmapped operand: replicate it along the new leading axis.
            arg = jnp.broadcast_to(arg, (axis_size,) + tuple(arg.shape))
        else:
            # Push the batch dimension to axis 0.
            arg = batching.moveaxis(arg, d, 0)
        aligned_args.append(arg)

    # Update the batch_size (index 3 in the descriptor).
    unpacked[3] *= axis_size
    new_opaque_data = struct.pack(fmt, *unpacked)

    # Prepend the new batch dimension to the output shape.
    new_out_shape = (axis_size,) + tuple(out_shape)

    out = _fim_prim.bind(
        *aligned_args,
        opaque_data=new_opaque_data,
        out_shape=new_out_shape,
        out_dtype=out_dtype
    )

    # Tell JAX that the batched dimension of the output is at axis 0.
    return out, 0


batching.primitive_batchers[_fim_prim] = _fim_batch_rule
|
|
219
|
+
|
|
220
|
+
# =========================================================
|
|
221
|
+
# 2. BASE XLA CUSTOM CALL
|
|
222
|
+
# =========================================================
|
|
223
|
+
def _fim_opaque_data(u_shape, f_shape, dx, msfm, is_3d, gated_x, is_backward, has_v, has_tof):
    """Pack the solver configuration into the byte descriptor read by the custom call.

    Layout ('=' native order, no padding): width, height, depth, batch_size
    (int32), dx (float32), then is_3d, is_backward, msfm, has_v, broadcast_f,
    gated_x, has_tof (int32).

    Every axis in front of the 2D/3D grid counts as a batch axis, and a bare
    grid is a batch of one. Using ``u_shape[0]`` instead would make a single
    (H, W) image look like H images and overrun the buffers.
    """
    import math

    grid_rank = 3 if is_3d else 2
    if is_3d:
        depth, height, width = u_shape[-3:]
    else:
        depth = 1
        height, width = u_shape[-2:]
    batch_size = math.prod(u_shape[:-grid_rank])

    # f is shared by the whole batch when it lacks the batch axis entirely,
    # or when it has a singleton batch axis while u has several entries.
    if len(f_shape) == len(u_shape) - 1:
        broadcast_f = True
    else:
        broadcast_f = (f_shape[0] == 1 and batch_size > 1)

    return struct.pack(
        '=iiiifiiiiiii',
        width, height, depth, batch_size, float(dx),
        int(is_3d), int(is_backward), int(msfm), int(has_v),
        int(broadcast_f), int(gated_x), int(has_tof)
    )


def _fim_custom_call(u_init, f, v, dx, msfm, is_3d, gated_x, is_backward, tof=None):
    """Bind the FIM primitive for one forward or adjoint (backward) solve.

    Args:
        u_init: Initial field; its shape and dtype define the output.
        f: Slowness map (forward) or incoming cotangent (backward).
        v: Optional advection field, appended as an operand when not None.
        dx: Grid spacing (packed as float32).
        msfm, is_3d, gated_x, is_backward: Solver flags.
        tof: Forward travel times; appended only for backward solves.

    Returns:
        The primitive's output, with the same shape and dtype as ``u_init``.
    """
    u_init = jnp.asarray(u_init)
    f = jnp.asarray(f)

    has_v = v is not None
    has_tof = tof is not None
    operands = [u_init, f]

    if has_v:
        operands.append(jnp.asarray(v))

    if is_backward and has_tof:
        operands.append(jnp.asarray(tof))

    opaque_data = _fim_opaque_data(
        u_init.shape, f.shape, dx, msfm, is_3d, gated_x, is_backward, has_v, has_tof
    )

    return _fim_prim.bind(
        *operands,
        opaque_data=opaque_data,
        out_shape=u_init.shape,
        out_dtype=u_init.dtype
    )
|
|
264
|
+
|
|
265
|
+
# =========================================================
|
|
266
|
+
# 3. VJP DEFINITION (Autograd support)
|
|
267
|
+
# =========================================================
|
|
268
|
+
def _infer_is_3d(u_init, is_3d):
    """Return ``is_3d`` as given, or infer it from ``u_init`` when it is None.

    Inference treats the array as a batch of volumes when it has at least four
    dims and a non-singleton third-to-last axis.
    """
    if is_3d is None:
        return u_init.ndim >= 4 and u_init.shape[-3] > 1
    return is_3d


# `v` (argument 2) is deliberately a regular, differentiable argument. JAX rejects
# Tracers in the `nondiff_argnums` positions of a custom_vjp function, so keeping
# an array-valued advection field there broke jax.jit / jax.vmap whenever v_init
# was supplied. Only the static configuration (dx and the flags) is non-diff.
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6))
def _solve_eikonal_base(u_init, f, v, dx, msfm, is_3d, gated_x):
    """Differentiable eikonal solve (w.r.t. u_init and f) via the FIM custom call."""
    is_3d = _infer_is_3d(u_init, is_3d)
    return _fim_custom_call(u_init, f, v, dx, msfm, is_3d, gated_x, is_backward=False)


def solve_eikonal_fwd(u_init, f, v, dx, msfm, is_3d, gated_x):
    """Forward pass: solve, and keep (u_out, f, v) as residuals for the adjoint."""
    is_3d = _infer_is_3d(u_init, is_3d)

    u_out = _fim_custom_call(u_init, f, v, dx, msfm, is_3d, gated_x, is_backward=False)

    # v is kept only so the backward pass can return a matching cotangent.
    res = (u_out, f, v)
    return u_out, res


def solve_eikonal_bwd(dx, msfm, is_3d, gated_x, res, grad_u):
    """Backward pass: adjoint solve seeded with the incoming cotangent ``grad_u``.

    Receives the non-differentiable arguments (dx, msfm, is_3d, gated_x) first,
    as custom_vjp requires. Returns cotangents for (u_init, f, v). The advection
    field is treated as a constant, so its cotangent is zero (None when absent).
    """
    u_out, f, v = res
    # Resolve None here too; previously a None flag reached the kernel as "2D".
    is_3d = _infer_is_3d(u_out, is_3d)
    lambda_init = jnp.zeros_like(u_out)

    lambda_adj = _fim_custom_call(
        lambda_init, grad_u, None, dx, msfm, is_3d, gated_x,
        is_backward=True, tof=u_out
    )

    grad_u_init = lambda_adj
    grad_f = lambda_adj * f * dx * dx

    # If f was broadcast over the batch, reduce its cotangent back to f's shape.
    if f.ndim == lambda_adj.ndim - 1:
        grad_f = jnp.sum(grad_f, axis=0)
    elif f.ndim == lambda_adj.ndim and f.shape[0] == 1 and lambda_adj.shape[0] > 1:
        grad_f = jnp.sum(grad_f, axis=0, keepdims=True)

    grad_v = None if v is None else jnp.zeros_like(v)
    return grad_u_init, grad_f, grad_v


_solve_eikonal_base.defvjp(solve_eikonal_fwd, solve_eikonal_bwd)
|
|
304
|
+
|
|
305
|
+
# =========================================================
|
|
306
|
+
# 4. PUBLIC API
|
|
307
|
+
# =========================================================
|
|
308
|
+
def eiko2d(u_init, f, v_init=None, dx=1.0, msfm=False, gated=False):
    """Solve the 2D eikonal equation for a JAX (H, W) grid or (B, H, W) batch.

    Returns the arrival-time field. If ``v_init`` is given, returns the tuple
    ``(u, v_init)``, with ``v_init`` passed through unchanged.

    Raises:
        ValueError: if ``u_init`` has more than 3 dimensions.
    """
    if u_init.ndim <= 3:
        solution = _solve_eikonal_base(u_init, f, v_init, dx, msfm, False, gated)
        return solution if v_init is None else (solution, v_init)
    raise ValueError(f"eiko2d expects a 2D or batched 2D grid (max 3 dims), got {u_init.ndim} dims.")
|
|
316
|
+
|
|
317
|
+
def eiko3d(u_init, f, v_init=None, dx=1.0, msfm=False, gated=False):
    """Solve the 3D eikonal equation for a JAX (D, H, W) volume or (B, D, H, W) batch.

    Returns the arrival-time field. If ``v_init`` is given, returns the tuple
    ``(u, v_init)``, with ``v_init`` passed through unchanged.

    Raises:
        ValueError: if ``u_init`` has fewer than 3 dimensions.
    """
    if u_init.ndim >= 3:
        solution = _solve_eikonal_base(u_init, f, v_init, dx, msfm, True, gated)
        return solution if v_init is None else (solution, v_init)
    raise ValueError(f"eiko3d expects a 3D or batched 3D grid (min 3 dims), got {u_init.ndim} dims.")
|
|
325
|
+
|
|
326
|
+
# Short alias: `eiko` is the same function object as the JAX 2D solver `eiko2d`.
eiko = eiko2d
|