eiko 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
eiko/__init__.py ADDED
@@ -0,0 +1,159 @@
1
+ import os
2
+ import sys
3
+
4
+ # ---------------------------------------------------------
5
+ # PATH RESOLUTION
6
+ # ---------------------------------------------------------
7
+ # Resolve paths once for the entire package.
8
+
9
# Directory that holds this file (Eiko/python/eiko/ or .../site-packages/eiko/).
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# Candidate A: pip-installed layout, where `src` ships inside the package.
SRC_DIR_INSTALLED = os.path.join(BASE_DIR, 'src')

# Candidate B: development checkout, where `src` sits two levels above this file.
SRC_DIR_DEV = os.path.abspath(os.path.join(BASE_DIR, '..', '..', 'src'))

# Take the first candidate that exists, preferring the installed layout.
SRC_DIR = next(
    (path for path in (SRC_DIR_INSTALLED, SRC_DIR_DEV) if os.path.exists(path)),
    None,
)
if SRC_DIR is None:
    raise FileNotFoundError(
        f"Could not find Eiko C++/CUDA source directory. Tried:\n"
        f"1. {SRC_DIR_INSTALLED}\n"
        f"2. {SRC_DIR_DEV}"
    )
29
+
30
+ # ---------------------------------------------------------
31
+ # COMPILER FLAGS
32
+ # ---------------------------------------------------------
33
# Host (C++) and device (NVCC) compiler flags, selected once per platform.
_ON_WINDOWS = sys.platform == "win32"

CXX_ARGS = (
    ['/std:c++20', '/Zc:preprocessor', '/DNOMINMAX']
    if _ON_WINDOWS
    else ['-std=c++20', '-O3']
)

NVCC_ARGS = (
    [
        '-std=c++20',
        '-allow-unsupported-compiler',
        # Options after -Xcompiler are forwarded to the MSVC host compiler.
        '-Xcompiler', '/Zc:preprocessor',
        '-Xcompiler', '/std:c++20',
        '-D_ALLOW_COMPILER_AND_STL_VERSION_MISMATCH',
        '-DNOMINMAX',
        '--use_fast_math',
        '-arch=native',
    ]
    if _ON_WINDOWS
    else ['-std=c++20', '--use_fast_math', '-arch=native']
)
48
+
49
+
50
# Header search paths handed to the JIT compiler; the package sources come first.
EXTRA_INCLUDE_PATHS = [SRC_DIR]

try:
    from torch.utils.cpp_extension import CUDA_HOME
except ImportError:
    # Let individual wrappers handle the missing torch error cleanly.
    pass
else:
    if CUDA_HOME:
        # Some CUDA toolkits place the CCCL headers under include/cccl;
        # add that directory and its sub-libraries when it is present.
        cccl_base = os.path.join(CUDA_HOME, 'include', 'cccl')
        if os.path.exists(cccl_base):
            EXTRA_INCLUDE_PATHS.extend(
                os.path.join(cccl_base, *parts)
                for parts in ((), ('thrust',), ('libcudacxx', 'include'), ('cub',))
            )
66
+
67
+ # ---------------------------------------------------------
68
+ # PUBLIC API EXPORTS
69
+ # ---------------------------------------------------------
70
+
71
+ def _is_jax_array(obj):
72
+ """Safely checks if an object is a JAX array without importing JAX."""
73
+ return type(obj).__module__.startswith('jax')
74
+
75
def eiko2d(u_init, f, v_init=None, dx=1.0, msfm=False, gated=False):
    """
    EIKO2D Shortest time-of-flight through an arbitrary 2D medium.

    Solves for the time-of-flight map (u) from a slowness map (f = 1/c) and
    initial conditions (u_init, set to infinity wherever the time is unknown).
    The call is forwarded to the JAX backend when `u_init` is a JAX array and
    to the PyTorch backend otherwise; the backend is imported on first use.

    EXAMPLE USAGES:
        u = eiko2d(u_init, f)                                 # (Standard usage).
        u = eiko2d(u_init, f, dx=0.5, msfm=True)              # (Named arguments).
        u, v_out = eiko2d(u_init, f, v_init=advection_field)  # (Advection).

    REQUIRED INPUTS:
        u_init - Known arrival times/delays.
                 Shape: (H, W) for one image, or (B, H, W) for a batch.
        f      - Slowness. Either (H, W) [broadcast to batch] or (B, H, W).

    OPTIONAL INPUTS:
        v_init - Initial advection field, same size as u. Default: None (not used).
        dx     - Grid spacing. Default: 1.0.
        msfm   - Enable Multi-Stencil Fast Marching (MSFM), which reduces
                 bias along diagonal directions. Default: False.
        gated  - Enforce positive propagation along the first data
                 dimension. Speeds up computations; valid when time only
                 increases when moving axially. Default: False.

    OUTPUTS:
        u - Arrival time (time-of-flight) map. Shape matches u_init.
        v - Advection vectors (returned only if v_init was supplied).
    """
    # Import lazily so that only the framework actually in use is required.
    if not _is_jax_array(u_init):
        from .eiko_torch import eiko2d as backend_eiko2d
    else:
        from .eiko_jax import eiko2d as backend_eiko2d
    return backend_eiko2d(u_init, f, v_init, dx, msfm, gated)
111
+
112
def eiko3d(u_init, f, v_init=None, dx=1.0, msfm=False, gated=False):
    """
    EIKO3D Shortest time-of-flight through an arbitrary 3D medium.

    Solves for the time-of-flight map (u) from a slowness map (f = 1/c) and
    initial conditions (u_init, set to infinity wherever the time is unknown).
    The call is forwarded to the JAX backend when `u_init` is a JAX array and
    to the PyTorch backend otherwise; the backend is imported on first use.

    EXAMPLE USAGES:
        u = eiko3d(u_init, f)                                 # (Standard usage).
        u = eiko3d(u_init, f, dx=0.5, msfm=True)              # (Named arguments).
        u, v_out = eiko3d(u_init, f, v_init=advection_field)  # (Advection).

    REQUIRED INPUTS:
        u_init - Known arrival times/delays.
                 Shape: (D, H, W) for one volume, or (B, D, H, W) for a batch.
        f      - Slowness. Either (D, H, W) [broadcast to batch] or (B, D, H, W).

    OPTIONAL INPUTS:
        v_init - Initial advection field, same size as u. Default: None (not used).
        dx     - Grid spacing. Default: 1.0.
        msfm   - Enable Multi-Stencil Fast Marching (MSFM), which reduces
                 bias along diagonal directions. Default: False.
        gated  - Enforce positive propagation along the first data
                 dimension. Speeds up computations; valid when time only
                 increases when moving axially. Default: False.

    OUTPUTS:
        u - Arrival time (time-of-flight) map. Shape matches u_init.
        v - Advection vectors (returned only if v_init was supplied).
    """
    # Import lazily so that only the framework actually in use is required.
    if not _is_jax_array(u_init):
        from .eiko_torch import eiko3d as backend_eiko3d
    else:
        from .eiko_jax import eiko3d as backend_eiko3d
    return backend_eiko3d(u_init, f, v_init, dx, msfm, gated)
148
+
149
# `eiko` is a short alias for the 2D solver.
eiko = eiko2d

from .animate_eikonal import animate_eikonal

# Names exported by `from eiko import *`. `eiko2d` is listed alongside its
# `eiko` alias so the canonical 2D entry point is exported as well.
__all__ = ['eiko', 'eiko2d', 'eiko3d', 'animate_eikonal']

# Expose the version of the installed distribution as `eiko.__version__`.
from importlib.metadata import PackageNotFoundError, version

try:
    __version__ = version("eiko")
except PackageNotFoundError:
    # The package can also be imported straight from a source checkout (the
    # development layout handled by the path resolution at the top of this
    # file), in which case no distribution metadata exists. Fall back to a
    # placeholder instead of making the import itself fail.
    __version__ = "0.0.0+unknown"
159
+
@@ -0,0 +1,201 @@
1
+ import sys
2
+
3
def animate_eikonal(u, v=1.0, color_map='gray', video_filename='',
                    title='EIKONAL WAVEFRONT', pulse_width=80.0,
                    speed=0.5, overlay=None, outline=None,
                    style='real', render_mode='slice'):
    """
    ANIMATE_EIKONAL Visualizes Eikonal equation travel times as an animated wave.

    A Hann-windowed pulse of width `pulse_width` is swept through time; in
    each frame every grid point shows the pulse evaluated at (u - t).

    INPUTS:
        u              - Travel-time field, 2D or 3D. Objects exposing `.cpu()`
                         are converted with `.detach().cpu().numpy()`.
        v              - Weight multiplied onto every frame (scalar or array
                         broadcastable against u). Default: 1.0.
        color_map      - Matplotlib colormap name. Default: 'gray'.
        video_filename - When non-empty, the animation is saved to this path
                         with an FFmpeg writer at 60 fps; otherwise it is
                         shown in a window.
        title          - Axes title.
        pulse_width    - Width of the pulse, in the units of u. Must be > 0.
        speed          - Time increment between frames. Must be > 0.
        overlay        - Optional 2D mask whose 0.5-level contours are filled
                         in red (2D only, needs scikit-image).
        outline        - Optional 2D mask whose 0.5-level contour is drawn in
                         black (2D only, drawn only when scikit-image is
                         available).
        style          - 'real', 'abs' or 'db' (case-insensitive).
        render_mode    - 3D only: 'slice' (three orthogonal mid-slices) or
                         'isosurface' (needs scikit-image).

    OUTPUTS:
        anim - The matplotlib FuncAnimation, or None when u holds no finite value.

    RAISES:
        ValueError  - Unknown style, non-positive speed/pulse_width, or u that
                      is neither 2D nor 3D.
        ImportError - numpy or matplotlib is missing.
    """
    # --- Argument validation (done first so mistakes fail fast) ---
    style_key = style.lower()
    color_limits = {'real': [-1, 1], 'abs': [0, 1], 'db': [-30, 0]}
    if style_key not in color_limits:
        raise ValueError(
            f"Unknown style {style!r}; expected one of {sorted(color_limits)}."
        )
    if speed <= 0:
        raise ValueError(f"speed must be positive, got {speed!r}.")
    if pulse_width <= 0:
        raise ValueError(f"pulse_width must be positive, got {pulse_width!r}.")
    clim = color_limits[style_key]

    # Required dependencies to run the visualization.
    try:
        import numpy as np
        import matplotlib.pyplot as plt
        from matplotlib.animation import FFMpegWriter, FuncAnimation
    except ImportError as exc:
        raise ImportError(
            "The animation utility requires 'numpy' and 'matplotlib'. "
            "Please install it using: pip install \"eiko[examples]\""
        ) from exc

    # Optional dependency for contours/isosurfaces.
    try:
        from skimage import measure
        HAS_SKIMAGE = True
    except ImportError:
        HAS_SKIMAGE = False
        print("[Eiko] Warning: 'scikit-image' not found. Contours, overlays, and 3D isosurfaces will be disabled.", file=sys.stderr)

    # --- Data Type & Device Checks (e.g., PyTorch -> Numpy) ---
    if hasattr(u, 'cpu'): u = u.detach().cpu().numpy()
    if hasattr(v, 'cpu'): v = v.detach().cpu().numpy()
    if hasattr(overlay, 'cpu'): overlay = overlay.detach().cpu().numpy()
    if hasattr(outline, 'cpu'): outline = outline.detach().cpu().numpy()

    # Accept any array-like (lists, other array libraries) from here on.
    u = np.asarray(u)
    if u.ndim not in (2, 3):
        raise ValueError(f"u must be a 2D or 3D field, got {u.ndim} dims.")
    is_3d = u.ndim == 3

    # --- Safe Data Initialization ---
    valid_u = u[np.isfinite(u)]
    if len(valid_u) == 0:
        print("Warning: Travel time field u is completely Inf. Nothing to animate.", file=sys.stderr)
        return

    maxv = np.max(valid_u)
    u_safe = u.astype(np.float32)  # astype returns a copy, u stays untouched
    # Unreached points get a time just past the last real arrival.
    u_safe[np.isinf(u_safe)] = maxv + pulse_width

    v_anim = v
    if is_3d:
        # Axis order is kept as given and interpreted as [Z, Y, X].
        Nz, Ny, Nx = u.shape
    else:
        Ny, Nx = u.shape

    # Reference level subtracted in 'db' style (+1e-12 prevents log10(0)).
    db_ref = 20 * np.log10(1.0 + 1e-12)

    # --- Figure & Axis Setup ---
    fig = plt.figure(figsize=(12, 8), facecolor='white')

    if is_3d:
        ax = fig.add_subplot(111, projection='3d')
        ax.set_title(title, fontsize=16, fontweight='bold')
        ax.set_xlabel('X (Lateral)')
        ax.set_ylabel('Y (Elevation)')
        ax.set_zlabel('Z (Depth)')
        ax.invert_zaxis()  # Ultrasound convention: Depth increases downwards

        # Mid-volume slice indices and isosurface level.
        sz, sy, sx = Nz // 2, Ny // 2, Nx // 2
        iso_val = -6 if style_key == 'db' else 0.5

        # Frame-invariant inputs of the slice renderer, built once.
        levels = np.linspace(clim[0], clim[1], 20)
        xy_x, xy_y = np.meshgrid(np.arange(Nx), np.arange(Ny))
        yz_y, yz_z = np.meshgrid(np.arange(Ny), np.arange(Nz))
        xz_x, xz_z = np.meshgrid(np.arange(Nx), np.arange(Nz))
    else:
        ax = fig.add_subplot(111)
        ax.set_title(title, fontsize=16, fontweight='bold', fontname='serif')
        ax.set_xlabel('X (Lateral)', fontsize=14, fontname='serif')
        ax.set_ylabel('Z (Depth)', fontsize=14, fontname='serif')

        # Start from an empty image. imshow puts (0, 0) at the top-left,
        # which already matches the depth-going-down convention.
        im = ax.imshow(np.zeros((Ny, Nx)), vmin=clim[0], vmax=clim[1], cmap=color_map, aspect='equal')

        if outline is not None and HAS_SKIMAGE:
            ax.contour(outline, levels=[0.5], colors='k', linewidths=2)

        if overlay is not None and HAS_SKIMAGE:
            for contour in measure.find_contours(overlay, 0.5):
                # contour is (row, col) -> (Y, X)
                ax.fill(contour[:, 1], contour[:, 0], color='red', alpha=0.3, edgecolor='r', linewidth=2.5)

    freq = 6 * np.pi

    # --- Animation Update Logic ---
    def compute_frame(t):
        """Calculates the analytical wave packet for a given time step."""
        diff_t = u_safe - t
        valid_mask = np.abs(diff_t) <= (pulse_width / 2.0)

        img = np.zeros_like(u_safe, dtype=np.complex64)

        if np.any(valid_mask):
            d_valid = diff_t[valid_mask]

            # Analytical Hanning envelope
            envelope = 0.5 * (1.0 + np.cos(2 * np.pi * d_valid / pulse_width))

            # Apply envelope and carrier phase
            img[valid_mask] = envelope * np.exp(1j * freq * d_valid / pulse_width)

        img = img * v_anim

        if style_key == 'real':
            return np.real(img)
        if style_key == 'abs':
            return np.abs(img)
        return 20 * np.log10(np.abs(img) + 1e-12) - db_ref  # 'db'

    # nanmin keeps the start time well-defined even if u contains NaN.
    start_time = np.nanmin(u_safe)
    time_steps = np.arange(start_time, maxv + pulse_width, speed)

    def update(frame_idx):
        final_img = compute_frame(time_steps[frame_idx])

        if not is_3d:
            im.set_data(final_img)
            return [im]

        # Drop the previous frame's artists. Iterate over a snapshot because
        # remove() mutates ax.collections.
        for coll in list(ax.collections):
            coll.remove()

        if render_mode == 'slice':
            # Re-draw the three orthogonal mid-slices.
            ax.contourf(xy_x, xy_y, final_img[sz, :, :], zdir='z', offset=sz, cmap=color_map, vmin=clim[0], vmax=clim[1], levels=levels)
            ax.contourf(final_img[:, :, sx], yz_y, yz_z, zdir='x', offset=sx, cmap=color_map, vmin=clim[0], vmax=clim[1], levels=levels)
            ax.contourf(xz_x, final_img[:, sy, :], xz_z, zdir='y', offset=sy, cmap=color_map, vmin=clim[0], vmax=clim[1], levels=levels)

        elif render_mode == 'isosurface' and HAS_SKIMAGE:
            try:
                # Use marching cubes to find the isosurface.
                verts, faces, _, _ = measure.marching_cubes(final_img, level=iso_val)
                from mpl_toolkits.mplot3d.art3d import Poly3DCollection
                mesh = Poly3DCollection(verts[faces], alpha=0.7)
                mesh.set_facecolor([0.2, 0.6, 1.0])
                mesh.set_edgecolor('none')
                ax.add_collection3d(mesh)
            except ValueError:
                pass  # Fails cleanly if the volume doesn't contain the isovalue yet

        return ax.collections

    # --- Run / Export Animation ---
    # Blit=True massively speeds up 2D rendering by only redrawing changed pixels.
    # Blit is not well-supported for matplotlib 3D axes.
    anim = FuncAnimation(fig, update, frames=len(time_steps),
                         interval=1000/60, blit=not is_3d)

    if video_filename:
        print(f"Exporting video to {video_filename} (this may take a moment)...")
        # Requires the ffmpeg executable to be installed and on PATH.
        writer = FFMpegWriter(fps=60, bitrate=2000)
        anim.save(video_filename, writer=writer)
        print("Video export complete.")
    else:
        plt.show()

    return anim
eiko/eiko_jax.py ADDED
@@ -0,0 +1,326 @@
1
+ import os
2
+ import sys
3
+ import struct
4
+ from functools import partial
5
+ from torch.utils.cpp_extension import load, _get_build_directory
6
+
7
+ try:
8
+ import jax
9
+ import pybind11
10
+ import jaxlib
11
+ except ImportError as e:
12
+ raise ImportError(
13
+ f"\n[Eiko] JAX bindings require 'jax', 'jaxlib', and 'pybind11' to be installed.\n"
14
+ f" Please install the required environment via: pip install \"eiko[jax]\"\n"
15
+ ) from e
16
+
17
+ import jax.numpy as jnp
18
+ from jax import core
19
+ from jax.interpreters import batching
20
+ from jax.interpreters import xla
21
+ from jax.interpreters import mlir
22
+
23
+ # Robust Primitive resolution for modern JAX.
24
+ if not hasattr(core, "Primitive"):
25
+ from jax._src.core import Primitive
26
+ core.Primitive = Primitive
27
+
28
+ # Version-agnostic xla_client import.
29
+ try:
30
+ from jaxlib import xla_client
31
+ except ImportError:
32
+ from jax.lib import xla_client
33
+
34
+ # Version-agnostic custom_call import.
35
+ #try:
36
+ # from jaxlib.hlo_helpers import custom_call
37
+ #except ImportError:
38
+ # from jax.interpreters.mlir import custom_call
39
+
40
+ from eiko import SRC_DIR, CXX_ARGS, NVCC_ARGS, EXTRA_INCLUDE_PATHS
41
+
42
+ # ---------------------------------------------------------
43
+ # JIT COMPILATION & LOADING
44
+ # ---------------------------------------------------------
45
# The torch extension build directory doubles as the "already compiled?" marker:
# a non-empty directory means a previous run left build artifacts behind.
build_dir = _get_build_directory('eiko_jax_impl', verbose=False)
is_cached = os.path.exists(build_dir) and bool(os.listdir(build_dir))

if not is_cached:
    # Flush immediately so the notice is visible before the long compile starts.
    print("[Eiko] First-time JAX initialization: JIT Compiling CUDA kernels for your GPU... (This may take a minute)", flush=True)

jax_source = os.path.join(SRC_DIR, 'bindings', 'jax_bindings.cu')
jax_includes = [*EXTRA_INCLUDE_PATHS, pybind11.get_include()]

# Build (or load from cache) the extension module holding the custom-call targets.
_fim_jax_impl = load(
    name="eiko_jax_impl",
    sources=[jax_source],
    extra_cflags=CXX_ARGS,
    extra_cuda_cflags=NVCC_ARGS,
    extra_include_paths=jax_includes,
    verbose=False,
)

# Make every target exported by the extension known to XLA's GPU platform.
_registrations = _fim_jax_impl.registrations()
for name, target in _registrations.items():
    xla_client.register_custom_call_target(name, target, platform="gpu")
66
+
67
+ # =========================================================
68
+ # 1. JAX PRIMITIVE DEFINITION & MLIR LOWERING
69
+ # =========================================================
70
# Single-result primitive that wraps the "jax_fim_solve" custom call.
_fim_prim = core.Primitive("jax_fim_solve")
_fim_prim.multiple_results = False

# Eager (non-jit) evaluation goes through JAX's generic apply_primitive.
# Older releases expose it as jax.interpreters.xla.apply_primitive; where that
# attribute is unavailable, fall back to the jax._src.dispatch implementation.
try:
    _apply_primitive = xla.apply_primitive
except AttributeError:
    from jax._src import dispatch as _dispatch
    _apply_primitive = _dispatch.apply_primitive

_fim_prim.def_impl(partial(_apply_primitive, _fim_prim))


def _fim_abstract_eval(*args, opaque_data, out_shape, out_dtype):
    """Abstract evaluation rule: the result's shape/dtype are exactly the
    `out_shape` and `out_dtype` parameters; the operands are not inspected."""
    return core.ShapedArray(out_shape, out_dtype)


_fim_prim.def_abstract_eval(_fim_abstract_eval)
78
+
79
def _build_custom_call_agnostic(call_target_name, result_types, operands,
                                operand_layouts, result_layouts, backend_config):
    """
    Emit a custom-call op using whichever construction path this JAX offers.

    Two helper wrappers are tried in order (`jaxlib.hlo_helpers.custom_call`,
    then `jax.interpreters.mlir.custom_call`); when neither can be used the
    `CustomCallOp` is assembled directly from MLIR attributes.

    Returns whatever the selected path produces; the caller reads `.results`
    from it.
    """
    helper_kwargs = dict(
        result_types=result_types,
        operands=operands,
        operand_layouts=operand_layouts,
        result_layouts=result_layouts,
        backend_config=backend_config,
    )

    # 1. Mid-era JAX wrapper (JAX >= 0.4.15 and < 0.4.30).
    try:
        from jaxlib.hlo_helpers import custom_call
        return custom_call(call_target_name, **helper_kwargs)
    except ImportError:
        pass

    # 2. Legacy JAX wrapper (JAX < 0.4.15).
    try:
        from jax.interpreters.mlir import custom_call
        return custom_call(call_target_name, **helper_kwargs)
    except (ImportError, AttributeError):
        pass

    # 3. Modern JAX (>= 0.4.30): no helper left, build the op by hand.
    import jaxlib.mlir.ir as ir
    try:
        from jaxlib.mlir.dialects import mhlo as hlo
    except ImportError:
        from jaxlib.mlir.dialects import hlo

    def _layout_attr(layouts):
        """Turn a list of dimension orders into an ArrayAttr of index tensors."""
        if layouts is None:
            return None
        import numpy as np

        def _one(layout):
            dims = np.array(layout, dtype=np.int64)
            # 1D tensor typed with MLIR's `index` element type.
            tensor_type = ir.RankedTensorType.get(dims.shape, ir.IndexType.get())
            return ir.DenseIntElementsAttr.get(dims, type=tensor_type)

        return ir.ArrayAttr.get([_one(layout) for layout in layouts])

    attrs = {
        "call_target_name": ir.StringAttr.get(call_target_name),
        "has_side_effect": ir.BoolAttr.get(False),
        "api_version": ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 2),
        "called_computations": ir.ArrayAttr.get([]),
    }

    # Optional attributes are attached only when the caller supplied them.
    if backend_config is not None:
        attrs["backend_config"] = ir.StringAttr.get(backend_config)
    for key, layouts in (("operand_layouts", operand_layouts),
                         ("result_layouts", result_layouts)):
        if layouts is not None:
            attrs[key] = _layout_attr(layouts)

    return hlo.CustomCallOp(result_types, operands, **attrs)
157
+
158
# MLIR lowering rule: the bridge between JAX's Python graph and XLA C++.
def _fim_lowering(ctx, *args, opaque_data, out_shape, out_dtype):
    """Lower the primitive to a "jax_fim_solve" custom call.

    `opaque_data` is forwarded as the backend config, and every operand and
    the single result are given the descending dimension order
    (rank-1, ..., 0) as their layout.
    """
    from jaxlib.mlir import ir

    def _descending_dims(rank):
        return tuple(reversed(range(rank)))

    # Result type: ranked tensor of `out_shape` with the MLIR element type
    # that corresponds to `out_dtype`.
    result_type = ir.RankedTensorType.get(
        out_shape, mlir.dtype_to_ir_type(out_dtype)
    )

    op = _build_custom_call_agnostic(
        "jax_fim_solve",
        result_types=[result_type],
        operands=args,
        operand_layouts=[_descending_dims(arg.type.rank) for arg in args],
        result_layouts=[_descending_dims(len(out_shape))],
        backend_config=opaque_data,
    )
    return op.results


mlir.register_lowering(_fim_prim, _fim_lowering, platform="gpu")
183
+
184
+ # =========================================================
185
+ # BATCHING RULE
186
+ # =========================================================
187
def _fim_batch_rule(batched_args, batch_dims, *, opaque_data, out_shape, out_dtype):
    """vmap rule: fold the mapped axis into the solver's batch dimension.

    The packed descriptor is unpacked, its batch size (field 3) is multiplied
    by the mapped axis size, and the primitive is re-bound on operands whose
    mapped axis has been moved to the front.

    Parameters:
        batched_args - Operands in bind order: u_init, f, then the optional
                       v / tof operands.
        batch_dims   - For each operand, the mapped axis or None.
        opaque_data, out_shape, out_dtype - Parameters of the unbatched call.

    Returns:
        (out, 0): the batched result and the position of its mapped axis.
    """
    opaque_fmt = '=iiiifiiiiiii'
    fields = list(struct.unpack(opaque_fmt, opaque_data))
    base_batch = fields[3]              # batch_size of the unbatched call
    n_spatial = 3 if fields[5] else 2   # field 5 is is_3d
    f_is_shared = bool(fields[9])       # field 9 is broadcast_f

    # Take the mapped size from an operand that is actually mapped; u_init
    # itself may be unmapped (e.g. only f is vmapped).
    axis_size = next(
        arg.shape[dim]
        for arg, dim in zip(batched_args, batch_dims)
        if dim is not None
    )

    aligned_args = []
    for index, (arg, dim) in enumerate(zip(batched_args, batch_dims)):
        is_shared_f = index == 1 and f_is_shared
        if dim is not None:
            arg = batching.moveaxis(arg, dim, 0)
            if is_shared_f:
                # Each mapped slice brings its own slowness map, so a single
                # shared map no longer describes the whole batch. Expand to
                # one map per sample and clear the broadcast flag.
                # NOTE(review): assumes the kernel reads `batch_size` maps
                # from f when broadcast_f is 0 — confirm against the bindings.
                spatial = tuple(arg.shape[-n_spatial:])
                arg = jnp.broadcast_to(
                    jnp.reshape(arg, (axis_size, 1) + spatial),
                    (axis_size, base_batch) + spatial,
                )
                fields[9] = 0
        elif not is_shared_f:
            # Unmapped per-sample operands are repeated along the new leading
            # axis so their size matches the enlarged batch. A shared f that
            # is unmapped stays a single map.
            arg = jnp.broadcast_to(arg, (axis_size,) + tuple(arg.shape))
        aligned_args.append(arg)

    fields[3] = base_batch * axis_size
    new_opaque_data = struct.pack(opaque_fmt, *fields)

    # Prepend the new batch dimension to the output shape.
    new_out_shape = (axis_size,) + tuple(out_shape)

    out = _fim_prim.bind(
        *aligned_args,
        opaque_data=new_opaque_data,
        out_shape=new_out_shape,
        out_dtype=out_dtype
    )

    # Tell JAX that the batched dimension of the output is at axis 0.
    return out, 0


batching.primitive_batchers[_fim_prim] = _fim_batch_rule
219
+
220
+ # =========================================================
221
+ # 2. BASE XLA CUSTOM CALL
222
+ # =========================================================
223
def _fim_custom_call(u_init, f, v, dx, msfm, is_3d, gated_x, is_backward, tof=None):
    """Pack the solver configuration and bind the FIM primitive.

    Parameters:
        u_init      - Initial field; its shape and dtype define the output.
                      The trailing 2 (or 3 when `is_3d`) axes are the grid,
                      any leading axes are treated as batch.
        f           - Second operand (slowness in the forward pass).
        v           - Optional third operand, appended when not None.
        dx          - Grid spacing, packed as a float.
        msfm, is_3d, gated_x, is_backward - Flags packed as ints.
        tof         - Optional operand appended only for the backward pass.

    Returns:
        The primitive's result, with the shape and dtype of `u_init`.

    Raises:
        ValueError - `u_init` has fewer axes than the grid dimensionality.
    """
    import math

    u_init = jnp.asarray(u_init)
    f = jnp.asarray(f)

    has_v = v is not None
    # tof only travels with the backward pass; keep the packed flag in step
    # with the operand list so the two can never disagree.
    has_tof = bool(is_backward) and tof is not None

    operands = [u_init, f]
    if has_v:
        operands.append(jnp.asarray(v))
    if has_tof:
        operands.append(jnp.asarray(tof))

    n_spatial = 3 if is_3d else 2
    if u_init.ndim < n_spatial:
        raise ValueError(
            f"u_init needs at least {n_spatial} dims, got {u_init.ndim}."
        )

    if is_3d:
        depth, height, width = u_init.shape[-3:]
    else:
        depth = 1
        height, width = u_init.shape[-2:]

    # Everything in front of the grid axes is batch. An unbatched grid has no
    # leading axes and therefore a batch size of 1 (shape[0] would wrongly
    # report the first grid extent in that case).
    batch_size = math.prod(u_init.shape[:-n_spatial])

    if f.ndim == u_init.ndim - 1:
        broadcast_f = True
    else:
        broadcast_f = (f.shape[0] == 1 and batch_size > 1)

    opaque_data = struct.pack(
        '=iiiifiiiiiii',
        width, height, depth, batch_size, float(dx),
        int(is_3d), int(is_backward), int(msfm), int(has_v),
        int(broadcast_f), int(gated_x), int(has_tof)
    )

    return _fim_prim.bind(
        *operands,
        opaque_data=opaque_data,
        out_shape=u_init.shape,
        out_dtype=u_init.dtype
    )
264
+
265
+ # =========================================================
266
+ # 3. VJP DEFINITION (Autograd support)
267
+ # =========================================================
268
def _resolve_is_3d(is_3d, reference):
    """Return `is_3d`, inferring it from `reference` when it is None.

    The inference treats an array with at least 4 axes whose third-from-last
    axis is longer than 1 as 3D.
    """
    if is_3d is None:
        return reference.ndim >= 4 and reference.shape[-3] > 1
    return is_3d


@partial(jax.custom_vjp, nondiff_argnums=(2, 3, 4, 5, 6))
def _solve_eikonal_base(u_init, f, v, dx, msfm, is_3d, gated_x):
    """Forward solve; differentiable with respect to `u_init` and `f` only."""
    return _fim_custom_call(
        u_init, f, v, dx, msfm, _resolve_is_3d(is_3d, u_init), gated_x,
        is_backward=False,
    )


def solve_eikonal_fwd(u_init, f, v, dx, msfm, is_3d, gated_x):
    """VJP forward pass: run the solve and keep (u_out, f) as residuals."""
    u_out = _fim_custom_call(
        u_init, f, v, dx, msfm, _resolve_is_3d(is_3d, u_init), gated_x,
        is_backward=False,
    )
    return u_out, (u_out, f)


def solve_eikonal_bwd(v, dx, msfm, is_3d, gated_x, res, grad_u):
    """VJP backward pass.

    Runs the solver in backward mode with a zero initial field, the incoming
    cotangent `grad_u` in the second operand slot and the forward result as
    `tof`, then derives the cotangents of `u_init` and `f` from its output.
    """
    u_out, f = res
    # The non-differentiable arguments arrive unchanged, so a None is_3d has
    # to be resolved here just as in the forward pass; the forward result has
    # the same shape as u_init.
    is_3d = _resolve_is_3d(is_3d, u_out)

    lambda_adj = _fim_custom_call(
        jnp.zeros_like(u_out), grad_u, None, dx, msfm, is_3d, gated_x,
        is_backward=True, tof=u_out
    )

    grad_u_init = lambda_adj
    grad_f = lambda_adj * f * dx * dx

    # When f was shared across the batch, its cotangent is the sum over the
    # batch axis, shaped like the f that was passed in.
    if f.ndim == lambda_adj.ndim - 1:
        grad_f = jnp.sum(grad_f, axis=0)
    elif f.ndim == lambda_adj.ndim and f.shape[0] == 1 and lambda_adj.shape[0] > 1:
        grad_f = jnp.sum(grad_f, axis=0, keepdims=True)

    return grad_u_init, grad_f


_solve_eikonal_base.defvjp(solve_eikonal_fwd, solve_eikonal_bwd)
304
+
305
+ # =========================================================
306
+ # 4. PUBLIC API
307
+ # =========================================================
308
def eiko2d(u_init, f, v_init=None, dx=1.0, msfm=False, gated=False):
    """JAX backend of the 2D solver.

    Rejects inputs with more than 3 axes, then runs the differentiable solve
    with the 3D flag off. Returns the solution alone, or the pair
    (solution, v_init) when `v_init` was given; the second element is the
    `v_init` object that was passed in.
    """
    ndim = u_init.ndim
    if ndim > 3:
        raise ValueError(f"eiko2d expects a 2D or batched 2D grid (max 3 dims), got {ndim} dims.")

    u_out = _solve_eikonal_base(u_init, f, v_init, dx, msfm, False, gated)
    return u_out if v_init is None else (u_out, v_init)


def eiko3d(u_init, f, v_init=None, dx=1.0, msfm=False, gated=False):
    """JAX backend of the 3D solver.

    Rejects inputs with fewer than 3 axes, then runs the differentiable solve
    with the 3D flag on. Returns the solution alone, or the pair
    (solution, v_init) when `v_init` was given; the second element is the
    `v_init` object that was passed in.
    """
    ndim = u_init.ndim
    if ndim < 3:
        raise ValueError(f"eiko3d expects a 3D or batched 3D grid (min 3 dims), got {ndim} dims.")

    u_out = _solve_eikonal_base(u_init, f, v_init, dx, msfm, True, gated)
    return u_out if v_init is None else (u_out, v_init)


# Short alias for the 2D solver.
eiko = eiko2d