vmec-jax 0.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126) hide show
  1. vmec_jax-0.0.1/LICENSE +21 -0
  2. vmec_jax-0.0.1/PKG-INFO +244 -0
  3. vmec_jax-0.0.1/README.md +213 -0
  4. vmec_jax-0.0.1/pyproject.toml +108 -0
  5. vmec_jax-0.0.1/setup.cfg +4 -0
  6. vmec_jax-0.0.1/tests/test_bcovar_lambda_axis_closure.py +78 -0
  7. vmec_jax-0.0.1/tests/test_bmag_parity.py +99 -0
  8. vmec_jax-0.0.1/tests/test_booz_input.py +32 -0
  9. vmec_jax-0.0.1/tests/test_boundary_eval.py +125 -0
  10. vmec_jax-0.0.1/tests/test_bsup_parity.py +122 -0
  11. vmec_jax-0.0.1/tests/test_c_kernels_axisym.py +46 -0
  12. vmec_jax-0.0.1/tests/test_chipf_inversion.py +40 -0
  13. vmec_jax-0.0.1/tests/test_chipf_mesh_convention.py +81 -0
  14. vmec_jax-0.0.1/tests/test_cli_vmec2000_exec.py +99 -0
  15. vmec_jax-0.0.1/tests/test_constraint_pipeline.py +168 -0
  16. vmec_jax-0.0.1/tests/test_coords_kernel.py +41 -0
  17. vmec_jax-0.0.1/tests/test_current_driven_iota_helper.py +43 -0
  18. vmec_jax-0.0.1/tests/test_driver_api.py +1196 -0
  19. vmec_jax-0.0.1/tests/test_dump_helpers.py +254 -0
  20. vmec_jax-0.0.1/tests/test_end_to_end_vmec_residual_gn.py +62 -0
  21. vmec_jax-0.0.1/tests/test_energy_integrals_parity.py +104 -0
  22. vmec_jax-0.0.1/tests/test_equif_eqfor_parity.py +136 -0
  23. vmec_jax-0.0.1/tests/test_fixaray_trig_tables.py +80 -0
  24. vmec_jax-0.0.1/tests/test_force_norms_dynamic_parity.py +85 -0
  25. vmec_jax-0.0.1/tests/test_free_boundary_wp0.py +1583 -0
  26. vmec_jax-0.0.1/tests/test_freeb_scalpot_compare_parser.py +203 -0
  27. vmec_jax-0.0.1/tests/test_geom_metrics.py +37 -0
  28. vmec_jax-0.0.1/tests/test_getfsq_block_sums.py +85 -0
  29. vmec_jax-0.0.1/tests/test_implicit_helpers.py +241 -0
  30. vmec_jax-0.0.1/tests/test_init_guess.py +56 -0
  31. vmec_jax-0.0.1/tests/test_multigrid_interp.py +90 -0
  32. vmec_jax-0.0.1/tests/test_namelist.py +49 -0
  33. vmec_jax-0.0.1/tests/test_nonaxis_exec_stage_trace_parity.py +52 -0
  34. vmec_jax-0.0.1/tests/test_optimization_helpers.py +51 -0
  35. vmec_jax-0.0.1/tests/test_parity_axis_rules.py +47 -0
  36. vmec_jax-0.0.1/tests/test_parity_sweep_manifest_thresholds.py +71 -0
  37. vmec_jax-0.0.1/tests/test_plotting_helpers.py +92 -0
  38. vmec_jax-0.0.1/tests/test_qi_wout_parity.py +84 -0
  39. vmec_jax-0.0.1/tests/test_regression_cases.py +180 -0
  40. vmec_jax-0.0.1/tests/test_residue_getfsq_parity.py +101 -0
  41. vmec_jax-0.0.1/tests/test_resume_state.py +105 -0
  42. vmec_jax-0.0.1/tests/test_solve_adaptive_controls.py +47 -0
  43. vmec_jax-0.0.1/tests/test_solve_hotpaths.py +128 -0
  44. vmec_jax-0.0.1/tests/test_solve_scan_chunking.py +87 -0
  45. vmec_jax-0.0.1/tests/test_step3_profiles.py +76 -0
  46. vmec_jax-0.0.1/tests/test_step4_field_cartesian.py +31 -0
  47. vmec_jax-0.0.1/tests/test_step4_field_energy.py +93 -0
  48. vmec_jax-0.0.1/tests/test_step5_solve_lambda.py +124 -0
  49. vmec_jax-0.0.1/tests/test_step6_solve_fixed_boundary.py +81 -0
  50. vmec_jax-0.0.1/tests/test_step8_preconditioner.py +127 -0
  51. vmec_jax-0.0.1/tests/test_step8_wout_stationarity.py +106 -0
  52. vmec_jax-0.0.1/tests/test_step9_implicit_fixed_boundary.py +81 -0
  53. vmec_jax-0.0.1/tests/test_step9_implicit_lambda.py +60 -0
  54. vmec_jax-0.0.1/tests/test_tcon_precondn_diag.py +408 -0
  55. vmec_jax-0.0.1/tests/test_visualization_vtk.py +40 -0
  56. vmec_jax-0.0.1/tests/test_vmec2000_exec_qa_regression.py +112 -0
  57. vmec_jax-0.0.1/tests/test_vmec2000_exec_qh_reactorscale_multigrid_override.py +53 -0
  58. vmec_jax-0.0.1/tests/test_vmec2000_exec_threed1.py +29 -0
  59. vmec_jax-0.0.1/tests/test_vmec2000_python_api_smoke.py +123 -0
  60. vmec_jax-0.0.1/tests/test_vmec2000_scalars_parity.py +153 -0
  61. vmec_jax-0.0.1/tests/test_vmec_alias_gcon.py +191 -0
  62. vmec_jax-0.0.1/tests/test_vmec_bcovar_smoke.py +52 -0
  63. vmec_jax-0.0.1/tests/test_vmec_forces_freeb_edge.py +50 -0
  64. vmec_jax-0.0.1/tests/test_vmec_jacobian_parity.py +118 -0
  65. vmec_jax-0.0.1/tests/test_vmec_parity_host.py +105 -0
  66. vmec_jax-0.0.1/tests/test_vmec_realspace.py +54 -0
  67. vmec_jax-0.0.1/tests/test_vmec_realspace_geom.py +114 -0
  68. vmec_jax-0.0.1/tests/test_vmec_tomnsp_tables.py +44 -0
  69. vmec_jax-0.0.1/tests/test_vmecpp_restart_flag.py +25 -0
  70. vmec_jax-0.0.1/tests/test_wout_parity_reference.py +194 -0
  71. vmec_jax-0.0.1/tests/test_wout_roundtrip.py +81 -0
  72. vmec_jax-0.0.1/tests/test_wout_vmecplot2_compat.py +91 -0
  73. vmec_jax-0.0.1/vmec_jax/__init__.py +292 -0
  74. vmec_jax-0.0.1/vmec_jax/__main__.py +12 -0
  75. vmec_jax-0.0.1/vmec_jax/_compat.py +211 -0
  76. vmec_jax-0.0.1/vmec_jax/api.py +86 -0
  77. vmec_jax-0.0.1/vmec_jax/booz_input.py +569 -0
  78. vmec_jax-0.0.1/vmec_jax/boundary.py +758 -0
  79. vmec_jax-0.0.1/vmec_jax/cli.py +245 -0
  80. vmec_jax-0.0.1/vmec_jax/config.py +161 -0
  81. vmec_jax-0.0.1/vmec_jax/coords.py +140 -0
  82. vmec_jax-0.0.1/vmec_jax/diagnostics.py +368 -0
  83. vmec_jax-0.0.1/vmec_jax/driver.py +2899 -0
  84. vmec_jax-0.0.1/vmec_jax/energy.py +279 -0
  85. vmec_jax-0.0.1/vmec_jax/field.py +408 -0
  86. vmec_jax-0.0.1/vmec_jax/fieldlines.py +148 -0
  87. vmec_jax-0.0.1/vmec_jax/fourier.py +307 -0
  88. vmec_jax-0.0.1/vmec_jax/free_boundary.py +3512 -0
  89. vmec_jax-0.0.1/vmec_jax/geom.py +335 -0
  90. vmec_jax-0.0.1/vmec_jax/grids.py +59 -0
  91. vmec_jax-0.0.1/vmec_jax/implicit.py +2067 -0
  92. vmec_jax-0.0.1/vmec_jax/init_guess.py +1225 -0
  93. vmec_jax-0.0.1/vmec_jax/integrals.py +184 -0
  94. vmec_jax-0.0.1/vmec_jax/modes.py +131 -0
  95. vmec_jax-0.0.1/vmec_jax/multigrid.py +290 -0
  96. vmec_jax-0.0.1/vmec_jax/namelist.py +235 -0
  97. vmec_jax-0.0.1/vmec_jax/nyquist.py +63 -0
  98. vmec_jax-0.0.1/vmec_jax/optimization.py +224 -0
  99. vmec_jax-0.0.1/vmec_jax/plotting.py +931 -0
  100. vmec_jax-0.0.1/vmec_jax/preconditioner_1d.py +558 -0
  101. vmec_jax-0.0.1/vmec_jax/preconditioner_1d_jax.py +1759 -0
  102. vmec_jax-0.0.1/vmec_jax/profiles.py +249 -0
  103. vmec_jax-0.0.1/vmec_jax/radial.py +53 -0
  104. vmec_jax-0.0.1/vmec_jax/residuals.py +196 -0
  105. vmec_jax-0.0.1/vmec_jax/solve.py +13893 -0
  106. vmec_jax-0.0.1/vmec_jax/state.py +170 -0
  107. vmec_jax-0.0.1/vmec_jax/static.py +255 -0
  108. vmec_jax-0.0.1/vmec_jax/visualization.py +352 -0
  109. vmec_jax-0.0.1/vmec_jax/vmec2000_exec.py +278 -0
  110. vmec_jax-0.0.1/vmec_jax/vmec_bcovar.py +1066 -0
  111. vmec_jax-0.0.1/vmec_jax/vmec_constraints.py +524 -0
  112. vmec_jax-0.0.1/vmec_jax/vmec_forces.py +1852 -0
  113. vmec_jax-0.0.1/vmec_jax/vmec_jacobian.py +317 -0
  114. vmec_jax-0.0.1/vmec_jax/vmec_lforbal.py +408 -0
  115. vmec_jax-0.0.1/vmec_jax/vmec_numpy_forces.py +938 -0
  116. vmec_jax-0.0.1/vmec_jax/vmec_parity.py +883 -0
  117. vmec_jax-0.0.1/vmec_jax/vmec_realspace.py +740 -0
  118. vmec_jax-0.0.1/vmec_jax/vmec_residue.py +900 -0
  119. vmec_jax-0.0.1/vmec_jax/vmec_tomnsp.py +1446 -0
  120. vmec_jax-0.0.1/vmec_jax/wout.py +6254 -0
  121. vmec_jax-0.0.1/vmec_jax.egg-info/PKG-INFO +244 -0
  122. vmec_jax-0.0.1/vmec_jax.egg-info/SOURCES.txt +124 -0
  123. vmec_jax-0.0.1/vmec_jax.egg-info/dependency_links.txt +1 -0
  124. vmec_jax-0.0.1/vmec_jax.egg-info/entry_points.txt +3 -0
  125. vmec_jax-0.0.1/vmec_jax.egg-info/requires.txt +27 -0
  126. vmec_jax-0.0.1/vmec_jax.egg-info/top_level.txt +3 -0
vmec_jax-0.0.1/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 UW Plasma
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,244 @@
1
+ Metadata-Version: 2.4
2
+ Name: vmec-jax
3
+ Version: 0.0.1
4
+ Summary: End-to-end differentiable JAX implementation of VMEC2000 for fixed and free-boundary equilibria.
5
+ Author: vmec_jax contributors
6
+ License-Expression: MIT
7
+ Requires-Python: >=3.10
8
+ Description-Content-Type: text/markdown
9
+ License-File: LICENSE
10
+ Requires-Dist: numpy
11
+ Requires-Dist: jax
12
+ Requires-Dist: jaxlib
13
+ Requires-Dist: netCDF4
14
+ Requires-Dist: tomli; python_version < "3.11"
15
+ Provides-Extra: jax
16
+ Requires-Dist: jax; extra == "jax"
17
+ Requires-Dist: jaxlib; extra == "jax"
18
+ Provides-Extra: netcdf
19
+ Requires-Dist: netCDF4; extra == "netcdf"
20
+ Provides-Extra: docs
21
+ Requires-Dist: sphinx; extra == "docs"
22
+ Requires-Dist: furo; extra == "docs"
23
+ Provides-Extra: plots
24
+ Requires-Dist: matplotlib; extra == "plots"
25
+ Provides-Extra: dev
26
+ Requires-Dist: pytest; extra == "dev"
27
+ Requires-Dist: ruff; extra == "dev"
28
+ Requires-Dist: mypy; extra == "dev"
29
+ Requires-Dist: types-setuptools; extra == "dev"
30
+ Dynamic: license-file
31
+
32
+ # vmec-jax
33
+
34
+ End-to-end differentiable JAX implementation of **VMEC2000** for fixed-boundary
35
+ and free-boundary ideal-MHD equilibria.
36
+
37
+ ## Showcase (single-grid)
38
+
39
+ All figures below use the same **single-grid** run settings: `NS_ARRAY=151`, `NITER_ARRAY=5000`, `FTOL_ARRAY=1e-14`, `NSTEP=500`.
40
+
41
+ <table>
42
+ <tr>
43
+ <td><img src="docs/_static/figures/axisym_compare_cross_sections.png" width="420" /></td>
44
+ <td><img src="docs/_static/figures/qa_compare_cross_sections.png" width="420" /></td>
45
+ </tr>
46
+ <tr>
47
+ <td align="center"><code>ITERModel</code> cross-section (VMEC2000 vs vmec_jax)</td>
48
+ <td align="center"><code>LandremanPaul2021_QA_lowres</code> cross-section (VMEC2000 vs vmec_jax)</td>
49
+ </tr>
50
+ <tr>
51
+ <td><img src="docs/_static/figures/axisym_compare_iota.png" width="420" /></td>
52
+ <td><img src="docs/_static/figures/qa_compare_iota.png" width="420" /></td>
53
+ </tr>
54
+ <tr>
55
+ <td align="center"><code>ITERModel</code> iota (VMEC2000 vs vmec_jax)</td>
56
+ <td align="center"><code>LandremanPaul2021_QA_lowres</code> iota (VMEC2000 vs vmec_jax)</td>
57
+ </tr>
58
+ </table>
59
+
60
+ <p align="center">
61
+ <img src="docs/_static/figures/readme_fsq_trace_single_grid.png" width="860" />
62
+ </p>
63
+
64
+ <p align="center">
65
+ <img src="docs/_static/figures/readme_runtime_compare.png" width="860" />
66
+ </p>
67
+
68
+ **Cold vs warm runtime**: the *cold* bar includes XLA JIT compilation on the first call (a one-time cost per process); the *warm* bar is the steady-state solve time for all subsequent calls in the same process, with the compiled kernels already in memory. VMEC2000 is a pre-compiled Fortran binary with no JIT step, so it has no compilation overhead and its cold and warm runtimes are effectively the same. The warm vmec_jax time is the fair comparison for repeated solves (e.g., in an optimization loop). Starting from v0.2, vmec_jax automatically caches compiled XLA kernels to disk (`~/.cache/vmec_jax/jax_cache`), so *cold* runs in a fresh process on the same machine benefit from the on-disk cache after the first invocation and approach warm-run speed.
69
+
70
+ ## More visuals (single-grid)
71
+
72
+ <table>
73
+ <tr>
74
+ <td><img src="docs/_static/figures/axisym_compare_3d.png" width="420" /></td>
75
+ <td><img src="docs/_static/figures/qa_compare_3d.png" width="420" /></td>
76
+ </tr>
77
+ <tr>
78
+ <td align="center"><code>ITERModel</code> 3D LCFS (VMEC2000 vs vmec_jax)</td>
79
+ <td align="center"><code>LandremanPaul2021_QA_lowres</code> 3D LCFS (VMEC2000 vs vmec_jax)</td>
80
+ </tr>
81
+ <tr>
82
+ <td><img src="docs/_static/figures/axisym_compare_bmag_surface.png" width="420" /></td>
83
+ <td><img src="docs/_static/figures/qa_compare_bmag_surface.png" width="420" /></td>
84
+ </tr>
85
+ <tr>
86
+ <td align="center"><code>ITERModel</code> |B| on LCFS (VMEC2000 vs vmec_jax)</td>
87
+ <td align="center"><code>LandremanPaul2021_QA_lowres</code> |B| on LCFS (VMEC2000 vs vmec_jax)</td>
88
+ </tr>
89
+ </table>
90
+
91
+ ## What it is
92
+
93
+ - VMEC2000-parity solver for fixed-boundary and free-boundary equilibria.
94
+ - Supports axisymmetric and non-axisymmetric configurations, in both symmetric form (`lasym=False`: stellarator-symmetric, or up-down-symmetric for axisymmetric cases) and asymmetric form (`lasym=True`).
95
+ - Default CLI path is `vmec_jax input.name`.
96
+ - `wout_*.nc` outputs, iteration diagnostics, and manifest-based parity sweeps are built around VMEC2000-compatible workflows.
97
+ - JAX-native kernels for geometry, transforms, and residual assembly.
98
+ - Differentiable optimization workflows are available through the Python API and bundled examples.
99
+
100
+ ## Quickstart
101
+
102
+ Install and run the showcase:
103
+
104
+ ```bash
105
+ python -m venv .venv
106
+ source .venv/bin/activate
107
+ python -m pip install -e .
108
+ python examples/showcase_axisym_input_to_wout.py --suite
109
+ ```
110
+
111
+ If you want a release-style non-editable install instead, use:
112
+
113
+ ```bash
114
+ python -m pip install .
115
+ ```
116
+
117
+ If you want the bundled reference outputs and mgrid files, fetch the assets once:
118
+
119
+ ```bash
120
+ python tools/fetch_assets.py
121
+ ```
122
+
123
+ Lightweight clone (keeps full history, downloads blobs lazily):
124
+
125
+ ```bash
126
+ git clone --filter=blob:none https://github.com/uwplasma/vmec_jax
127
+ ```
128
+
129
+ Note: the repo history was rewritten on 2026-03-16 to remove large assets from
130
+ all commits. If you cloned before that date, please re-clone (or prune and
131
+ reset) to get the smaller history.
132
+
133
+ CLI (VMEC2000-style executable):
134
+
135
+ ```bash
136
+ vmec_jax examples/data/input.circular_tokamak
137
+ ```
138
+
139
+ Sanity check (verifies the console script is wired to the right interpreter):
140
+
141
+ ```bash
142
+ vmec_jax --help
143
+ ```
144
+
145
+ If the `vmec_jax` command is not found or raises `ModuleNotFoundError`, make sure
146
+ you installed with the same interpreter and use the module entrypoint:
147
+
148
+ ```bash
149
+ python -m pip install -e .
150
+ python -m vmec_jax examples/data/input.circular_tokamak
151
+ ```
152
+
153
+ For fixed-boundary inputs, the default CLI path now uses the optimized
154
+ controller: it tries the fast final-grid scan route first, then escalates to
155
+ staged continuation and strict parity finishing only when the input structure
156
+ and residual history require it. Pass `--parity` to force the conservative
157
+ VMEC2000 loop. Pass `--solver-mode accelerated` to request the optimized track
158
+ explicitly.
159
+
160
+ Python driver comparison (reference track vs optimized CLI-style track):
161
+
162
+ ```bash
163
+ python examples/fixed_boundary_driver_tracks.py \
164
+ examples/data/input.circular_tokamak \
165
+ --quiet --json
166
+ ```
167
+
168
+ Run tests:
169
+
170
+ ```bash
171
+ pytest -q
172
+ ```
173
+
174
+ Full test suite (requires netCDF assets):
175
+
176
+ ```bash
177
+ python tools/fetch_assets.py
178
+ RUN_FULL=1 pytest -q
179
+ ```
180
+
181
+ Advanced optimization examples live in `examples/optimization/`. They are
182
+ intended as deeper workflow templates rather than README quickstarts, so use
183
+ the fixed-boundary driver example above as the validated copy/paste entry point
184
+ and then adapt the optimization scripts for your target objective. The simplest
185
+ starting point is:
186
+
187
+ ```bash
188
+ python examples/optimization/target_iota_aspect_volume.py --opt-steps 2
189
+ ```
190
+
191
+ That example keeps the boundary parameterization small (`max(|m|, |n|) <= 1`),
192
+ targets equilibrium volume, aspect ratio, and mean iota, and defaults to the
193
+ bundled current-driven `cth_like_fixed_bdy` case so the iota channel is active.
194
+
195
+ ## Performance vs parity
196
+
197
+ - Default runs aim for VMEC2000-compatible behavior while selecting the fastest stable path for the input.
198
+ - Use `--parity` (or `performance_mode=False` in Python) to force the conservative VMEC2000 loop.
199
+ - Use `--solver-mode accelerated` to force the optimized fixed-boundary controller explicitly.
200
+
201
+ Details, profiling guidance, and parity methodology:
202
+
203
+ - `docs/performance.rst`
204
+ - `docs/validation.rst`
205
+ - `tools/diagnostics/parity_manifest.toml` + `tools/diagnostics/parity_sweep_manifest.py`
206
+
207
+ ## VMEC++ notes
208
+
209
+ The current runtime benchmark compares vmec_jax against VMEC2000. VMEC++ is not included in this benchmark.
210
+
211
+ When VMEC++ is available, it can be added to the runtime plot via `--cpu-summary` entries with `backend=vmecpp`. Some inputs are not supported or do not converge under the same single-grid settings:
212
+
213
+ VMEC++ unsupported inputs (`lasym=True`):
214
+
215
+ - `LandremanSenguptaPlunk_section5p3_low_res`
216
+ - `basic_non_stellsym_pressure`
217
+ - `cth_like_free_bdy_lasym_small`
218
+ - `up_down_asymmetric_tokamak`
219
+
220
+ VMEC++ known non-convergence on these `lasym=False` cases under the same single-grid settings:
221
+
222
+ - `DIII-D_lasym_false`
223
+ - `LandremanPaul2021_QA_reactorScale_lowres`
224
+ - `LandremanPaul2021_QH_reactorScale_lowres`
225
+ - `LandremanSengupta2019_section5.4_B2_A80`
226
+ - `cth_like_fixed_bdy`
227
+
228
+ ## CLI output and `NSTEP`
229
+
230
+ The VMEC-style iteration loop prints every `NSTEP` iterations. Larger `NSTEP` means fewer print callbacks and faster runs.
231
+
232
+ To disable live printing, set:
233
+
234
+ ```bash
235
+ export VMEC_JAX_SCAN_PRINT=0
236
+ ```
237
+
238
+ Quiet runs (`--quiet` or `verbose=False`) default the scan path to a minimal
239
+ history mode (only `fsqr/fsqz/fsql` and `w_history` are kept) to reduce
240
+ host/device traffic. You can override this with:
241
+
242
+ ```bash
243
+ export VMEC_JAX_SCAN_MINIMAL=0 # keep full scan diagnostics even when quiet
244
+ ```
@@ -0,0 +1,213 @@
1
+ # vmec-jax
2
+
3
+ End-to-end differentiable JAX implementation of **VMEC2000** for fixed-boundary
4
+ and free-boundary ideal-MHD equilibria.
5
+
6
+ ## Showcase (single-grid)
7
+
8
+ All figures below use the same **single-grid** run settings: `NS_ARRAY=151`, `NITER_ARRAY=5000`, `FTOL_ARRAY=1e-14`, `NSTEP=500`.
9
+
10
+ <table>
11
+ <tr>
12
+ <td><img src="docs/_static/figures/axisym_compare_cross_sections.png" width="420" /></td>
13
+ <td><img src="docs/_static/figures/qa_compare_cross_sections.png" width="420" /></td>
14
+ </tr>
15
+ <tr>
16
+ <td align="center"><code>ITERModel</code> cross-section (VMEC2000 vs vmec_jax)</td>
17
+ <td align="center"><code>LandremanPaul2021_QA_lowres</code> cross-section (VMEC2000 vs vmec_jax)</td>
18
+ </tr>
19
+ <tr>
20
+ <td><img src="docs/_static/figures/axisym_compare_iota.png" width="420" /></td>
21
+ <td><img src="docs/_static/figures/qa_compare_iota.png" width="420" /></td>
22
+ </tr>
23
+ <tr>
24
+ <td align="center"><code>ITERModel</code> iota (VMEC2000 vs vmec_jax)</td>
25
+ <td align="center"><code>LandremanPaul2021_QA_lowres</code> iota (VMEC2000 vs vmec_jax)</td>
26
+ </tr>
27
+ </table>
28
+
29
+ <p align="center">
30
+ <img src="docs/_static/figures/readme_fsq_trace_single_grid.png" width="860" />
31
+ </p>
32
+
33
+ <p align="center">
34
+ <img src="docs/_static/figures/readme_runtime_compare.png" width="860" />
35
+ </p>
36
+
37
+ **Cold vs warm runtime**: the *cold* bar includes XLA JIT compilation on the first call (a one-time cost per process); the *warm* bar is the steady-state solve time for all subsequent calls in the same process, with the compiled kernels already in memory. VMEC2000 is a pre-compiled Fortran binary with no JIT step, so it has no compilation overhead and its cold and warm runtimes are effectively the same. The warm vmec_jax time is the fair comparison for repeated solves (e.g., in an optimization loop). Starting from v0.2, vmec_jax automatically caches compiled XLA kernels to disk (`~/.cache/vmec_jax/jax_cache`), so *cold* runs in a fresh process on the same machine benefit from the on-disk cache after the first invocation and approach warm-run speed.
38
+
39
+ ## More visuals (single-grid)
40
+
41
+ <table>
42
+ <tr>
43
+ <td><img src="docs/_static/figures/axisym_compare_3d.png" width="420" /></td>
44
+ <td><img src="docs/_static/figures/qa_compare_3d.png" width="420" /></td>
45
+ </tr>
46
+ <tr>
47
+ <td align="center"><code>ITERModel</code> 3D LCFS (VMEC2000 vs vmec_jax)</td>
48
+ <td align="center"><code>LandremanPaul2021_QA_lowres</code> 3D LCFS (VMEC2000 vs vmec_jax)</td>
49
+ </tr>
50
+ <tr>
51
+ <td><img src="docs/_static/figures/axisym_compare_bmag_surface.png" width="420" /></td>
52
+ <td><img src="docs/_static/figures/qa_compare_bmag_surface.png" width="420" /></td>
53
+ </tr>
54
+ <tr>
55
+ <td align="center"><code>ITERModel</code> |B| on LCFS (VMEC2000 vs vmec_jax)</td>
56
+ <td align="center"><code>LandremanPaul2021_QA_lowres</code> |B| on LCFS (VMEC2000 vs vmec_jax)</td>
57
+ </tr>
58
+ </table>
59
+
60
+ ## What it is
61
+
62
+ - VMEC2000-parity solver for fixed-boundary and free-boundary equilibria.
63
+ - Supports axisymmetric and non-axisymmetric configurations, in both symmetric form (`lasym=False`: stellarator-symmetric, or up-down-symmetric for axisymmetric cases) and asymmetric form (`lasym=True`).
64
+ - Default CLI path is `vmec_jax input.name`.
65
+ - `wout_*.nc` outputs, iteration diagnostics, and manifest-based parity sweeps are built around VMEC2000-compatible workflows.
66
+ - JAX-native kernels for geometry, transforms, and residual assembly.
67
+ - Differentiable optimization workflows are available through the Python API and bundled examples.
68
+
69
+ ## Quickstart
70
+
71
+ Install and run the showcase:
72
+
73
+ ```bash
74
+ python -m venv .venv
75
+ source .venv/bin/activate
76
+ python -m pip install -e .
77
+ python examples/showcase_axisym_input_to_wout.py --suite
78
+ ```
79
+
80
+ If you want a release-style non-editable install instead, use:
81
+
82
+ ```bash
83
+ python -m pip install .
84
+ ```
85
+
86
+ If you want the bundled reference outputs and mgrid files, fetch the assets once:
87
+
88
+ ```bash
89
+ python tools/fetch_assets.py
90
+ ```
91
+
92
+ Lightweight clone (keeps full history, downloads blobs lazily):
93
+
94
+ ```bash
95
+ git clone --filter=blob:none https://github.com/uwplasma/vmec_jax
96
+ ```
97
+
98
+ Note: the repo history was rewritten on 2026-03-16 to remove large assets from
99
+ all commits. If you cloned before that date, please re-clone (or prune and
100
+ reset) to get the smaller history.
101
+
102
+ CLI (VMEC2000-style executable):
103
+
104
+ ```bash
105
+ vmec_jax examples/data/input.circular_tokamak
106
+ ```
107
+
108
+ Sanity check (verifies the console script is wired to the right interpreter):
109
+
110
+ ```bash
111
+ vmec_jax --help
112
+ ```
113
+
114
+ If the `vmec_jax` command is not found or raises `ModuleNotFoundError`, make sure
115
+ you installed with the same interpreter and use the module entrypoint:
116
+
117
+ ```bash
118
+ python -m pip install -e .
119
+ python -m vmec_jax examples/data/input.circular_tokamak
120
+ ```
121
+
122
+ For fixed-boundary inputs, the default CLI path now uses the optimized
123
+ controller: it tries the fast final-grid scan route first, then escalates to
124
+ staged continuation and strict parity finishing only when the input structure
125
+ and residual history require it. Pass `--parity` to force the conservative
126
+ VMEC2000 loop. Pass `--solver-mode accelerated` to request the optimized track
127
+ explicitly.
128
+
129
+ Python driver comparison (reference track vs optimized CLI-style track):
130
+
131
+ ```bash
132
+ python examples/fixed_boundary_driver_tracks.py \
133
+ examples/data/input.circular_tokamak \
134
+ --quiet --json
135
+ ```
136
+
137
+ Run tests:
138
+
139
+ ```bash
140
+ pytest -q
141
+ ```
142
+
143
+ Full test suite (requires netCDF assets):
144
+
145
+ ```bash
146
+ python tools/fetch_assets.py
147
+ RUN_FULL=1 pytest -q
148
+ ```
149
+
150
+ Advanced optimization examples live in `examples/optimization/`. They are
151
+ intended as deeper workflow templates rather than README quickstarts, so use
152
+ the fixed-boundary driver example above as the validated copy/paste entry point
153
+ and then adapt the optimization scripts for your target objective. The simplest
154
+ starting point is:
155
+
156
+ ```bash
157
+ python examples/optimization/target_iota_aspect_volume.py --opt-steps 2
158
+ ```
159
+
160
+ That example keeps the boundary parameterization small (`max(|m|, |n|) <= 1`),
161
+ targets equilibrium volume, aspect ratio, and mean iota, and defaults to the
162
+ bundled current-driven `cth_like_fixed_bdy` case so the iota channel is active.
163
+
164
+ ## Performance vs parity
165
+
166
+ - Default runs aim for VMEC2000-compatible behavior while selecting the fastest stable path for the input.
167
+ - Use `--parity` (or `performance_mode=False` in Python) to force the conservative VMEC2000 loop.
168
+ - Use `--solver-mode accelerated` to force the optimized fixed-boundary controller explicitly.
169
+
170
+ Details, profiling guidance, and parity methodology:
171
+
172
+ - `docs/performance.rst`
173
+ - `docs/validation.rst`
174
+ - `tools/diagnostics/parity_manifest.toml` + `tools/diagnostics/parity_sweep_manifest.py`
175
+
176
+ ## VMEC++ notes
177
+
178
+ The current runtime benchmark compares vmec_jax against VMEC2000. VMEC++ is not included in this benchmark.
179
+
180
+ When VMEC++ is available, it can be added to the runtime plot via `--cpu-summary` entries with `backend=vmecpp`. Some inputs are not supported or do not converge under the same single-grid settings:
181
+
182
+ VMEC++ unsupported inputs (`lasym=True`):
183
+
184
+ - `LandremanSenguptaPlunk_section5p3_low_res`
185
+ - `basic_non_stellsym_pressure`
186
+ - `cth_like_free_bdy_lasym_small`
187
+ - `up_down_asymmetric_tokamak`
188
+
189
+ VMEC++ known non-convergence on these `lasym=False` cases under the same single-grid settings:
190
+
191
+ - `DIII-D_lasym_false`
192
+ - `LandremanPaul2021_QA_reactorScale_lowres`
193
+ - `LandremanPaul2021_QH_reactorScale_lowres`
194
+ - `LandremanSengupta2019_section5.4_B2_A80`
195
+ - `cth_like_fixed_bdy`
196
+
197
+ ## CLI output and `NSTEP`
198
+
199
+ The VMEC-style iteration loop prints every `NSTEP` iterations. Larger `NSTEP` means fewer print callbacks and faster runs.
200
+
201
+ To disable live printing, set:
202
+
203
+ ```bash
204
+ export VMEC_JAX_SCAN_PRINT=0
205
+ ```
206
+
207
+ Quiet runs (`--quiet` or `verbose=False`) default the scan path to a minimal
208
+ history mode (only `fsqr/fsqz/fsql` and `w_history` are kept) to reduce
209
+ host/device traffic. You can override this with:
210
+
211
+ ```bash
212
+ export VMEC_JAX_SCAN_MINIMAL=0 # keep full scan diagnostics even when quiet
213
+ ```
@@ -0,0 +1,108 @@
1
+ [build-system]
2
+ requires = ["setuptools>=68", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "vmec-jax"
7
+ version = "0.0.1"
8
+ description = "End-to-end differentiable JAX implementation of VMEC2000 for fixed and free-boundary equilibria."
9
+ readme = "README.md"
10
+ requires-python = ">=3.10"
11
+ authors = [{name = "vmec_jax contributors"}]
12
+ license = "MIT"
13
+
14
+ dependencies = [
15
+ "numpy",
16
+ "jax",
17
+ "jaxlib",
18
+ "netCDF4",
19
+ "tomli; python_version < '3.11'",
20
+ ]
21
+
22
+ [project.scripts]
23
+ vmec_jax = "vmec_jax.cli:main"
24
+ xvmec_jax = "vmec_jax.cli:main"
25
+
26
+ [project.optional-dependencies]
27
+ jax = ["jax", "jaxlib"]
28
+ netcdf = ["netCDF4"]
29
+ docs = [
30
+ "sphinx",
31
+ "furo",
32
+ ]
33
+ plots = ["matplotlib"]
34
+ dev = ["pytest", "ruff", "mypy", "types-setuptools"]
35
+
36
+ [tool.setuptools.packages.find]
37
+ where = ["."]
38
+ exclude = ["tests*", "docs*", "examples*", "tools*"]
39
+
40
+ [tool.pytest.ini_options]
41
+ markers = [
42
+ "vmec2000: integration tests requiring a local VMEC2000 build (+ mpi4py/netCDF4)",
43
+ "full: tests that require large netCDF assets (run tools/fetch_assets.py)",
44
+ ]
45
+
46
+ [tool.ruff]
47
+ line-length = 120
48
+
49
+ [tool.ruff.lint]
50
+ # E402: module-level import not at top-of-file (imports inside functions are not
51
+ # affected). Some modules intentionally run setup code before importing, e.g. to
52
+ # control JAX/GPU initialisation. Silencing E402 globally is intentional.
53
+ # F821: undefined name in annotation. Ruff occasionally reports false positives
54
+ # inside very long functions (closures over outer-scope locals) and for
55
+ # TYPE_CHECKING-only imports used in annotations. Silence globally and
56
+ # rely on mypy for real undefined-name checks.
57
+ # F841: local variable is assigned to but never used. Many VMEC routines assign
58
+ # intermediate diagnostics for debugging that are not always used afterwards.
59
+ # mypy has no unused-variable check, so per-line `# noqa: F841` markers would
60
+ # be the finer-grained alternative to this global ignore.
61
+ # E501: line too long. Lint enforcement is off (line-length above still guides
62
+ # `ruff format`); long lines in generated / ported Fortran-style code are acceptable.
63
+ ignore = ["E402", "F821", "F841", "E501"]
64
+
65
+ # Allow single-statement blocks on the same line as the colon (E701) and
66
+ # semicolons as statement separators (E702) in tightly coupled numerical kernels.
67
+ # These are common in ported Fortran code and do not impair readability.
68
+ # E731: lambda-assignment — allow where lambdas are used as concise callbacks.
69
+ # E741: ambiguous variable name — l/O/I are used as loop indices in physics code.
70
+ extend-ignore = ["E701", "E702", "E731", "E741"]
71
+
72
+ [tool.mypy]
73
+ # vmec_jax is a JAX/NumPy numerical physics codebase ported from Fortran.
74
+ # Many type errors are pre-existing false positives stemming from:
75
+ # - JAX's dynamically-typed array operations (jax.Array is opaque to mypy)
76
+ # - Very long functions with nested closures (mypy loses type context)
77
+ # - TYPE_CHECKING-guarded imports used in annotations
78
+ # - Ported Fortran patterns (lots of Any-typed intermediate arrays)
79
+ # We configure mypy conservatively: catch obvious bugs but avoid blocking
80
+ # development with spurious false-positive noise.
81
+ python_version = "3.10"
82
+ ignore_missing_imports = true
83
+ no_strict_optional = true
84
+ warn_return_any = false
85
+ warn_unused_ignores = false
86
+ # Disable the noisiest checks, which mostly produce false positives here:
87
+ disable_error_code = ["attr-defined", "operator", "valid-type", "name-defined", "call-arg", "return-value", "arg-type", "assignment", "index", "union-attr", "misc", "type-arg", "no-any-return", "override"]
88
+
89
+ [[tool.mypy.overrides]]
90
+ module = [
91
+ "vmec_jax.solve",
92
+ "vmec_jax.driver",
93
+ "vmec_jax.boundary",
94
+ "vmec_jax.vmec_forces",
95
+ "vmec_jax.geom",
96
+ "vmec_jax.wout",
97
+ "vmec_jax.vmec_parity",
98
+ "vmec_jax.fourier",
99
+ "vmec_jax.static",
100
+ "vmec_jax.implicit",
101
+ "vmec_jax.diagnostics",
102
+ "vmec_jax.free_boundary",
103
+ "vmec_jax.init_guess",
104
+ "vmec_jax.booz_input",
105
+ "vmec_jax.vmec_residue",
106
+ "vmec_jax.vmec_tomnsp",
107
+ ]
108
+ ignore_errors = true
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,78 @@
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+
5
+ from vmec_jax.vmec_bcovar import _apply_vmec_lambda_axis_closure
6
+
7
+
8
def test_lambda_axis_closure_copies_m0_npos_modes():
    """Axis entries of (m=0, n>0) lambda modes are copied from the first interior row.

    Expected values are built from a snapshot taken *before* the call rather
    than compared against ``lsin`` afterwards, so an in-place edit of the input
    by the function under test cannot make the assertions vacuous.
    """
    lsin = np.array(
        [
            [1.0, 2.0, 3.0, 4.0],
            [10.0, 20.0, 30.0, 40.0],
            [100.0, 200.0, 300.0, 400.0],
        ],
        dtype=float,
    )
    lsin_before = lsin.copy()
    m_modes = np.array([0, 0, 1, 0], dtype=int)
    n_modes = np.array([0, 1, 1, 2], dtype=int)

    out = np.asarray(
        _apply_vmec_lambda_axis_closure(
            Lsin=lsin,
            m_modes=m_modes,
            n_modes=n_modes,
            lthreed=True,
            ntor=2,
        )
    )

    expected = lsin_before.copy()
    # Columns 1 and 3 are the (m=0, n>0) modes: their row-0 entry takes the
    # row-1 value. Column 0 (n=0), column 2 (m!=0) and all rows >= 1 keep
    # their original values.
    expected[0, [1, 3]] = lsin_before[1, [1, 3]]
    np.testing.assert_allclose(out, expected)
35
+
36
+
37
def test_lambda_axis_closure_disabled_for_axisymmetric_or_ntor_zero():
    """The closure leaves the input unchanged when ``lthreed`` is False or ``ntor == 0``.

    Each call receives its own copy of the input, and results are checked
    against a snapshot taken before any call. An in-place edit by one call can
    therefore neither leak into the other call nor make the comparison vacuous.
    """
    lsin = np.array([[1.0, 2.0], [10.0, 20.0]], dtype=float)
    expected = lsin.copy()
    m_modes = np.array([0, 0], dtype=int)
    n_modes = np.array([0, 1], dtype=int)

    out_axis = np.asarray(
        _apply_vmec_lambda_axis_closure(
            Lsin=lsin.copy(),
            m_modes=m_modes,
            n_modes=n_modes,
            lthreed=False,
            ntor=1,
        )
    )
    out_ntor0 = np.asarray(
        _apply_vmec_lambda_axis_closure(
            Lsin=lsin.copy(),
            m_modes=m_modes,
            n_modes=n_modes,
            lthreed=True,
            ntor=0,
        )
    )
    np.testing.assert_allclose(out_axis, expected)
    np.testing.assert_allclose(out_ntor0, expected)
62
+
63
+
64
def test_lambda_axis_closure_noop_when_ns_one():
    """With a single radial row there is no interior row to copy from, so nothing changes.

    The comparison uses a pre-call snapshot so an in-place edit of ``lsin``
    cannot make the assertion vacuous.
    """
    lsin = np.array([[1.0, 2.0, 3.0]], dtype=float)
    expected = lsin.copy()
    m_modes = np.array([0, 0, 1], dtype=int)
    n_modes = np.array([0, 2, 1], dtype=int)

    out = np.asarray(
        _apply_vmec_lambda_axis_closure(
            Lsin=lsin,
            m_modes=m_modes,
            n_modes=n_modes,
            lthreed=True,
            ntor=2,
        )
    )
    np.testing.assert_allclose(out, expected)