xpcsjax 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213) hide show
  1. xpcsjax/__init__.py +166 -0
  2. xpcsjax/cli/__init__.py +78 -0
  3. xpcsjax/cli/args_parser.py +442 -0
  4. xpcsjax/cli/commands.py +181 -0
  5. xpcsjax/cli/config_generator.py +218 -0
  6. xpcsjax/cli/config_handling.py +327 -0
  7. xpcsjax/cli/config_template.py +375 -0
  8. xpcsjax/cli/data_pipeline.py +49 -0
  9. xpcsjax/cli/main.py +195 -0
  10. xpcsjax/cli/optimization_runner.py +329 -0
  11. xpcsjax/cli/plot_backend.py +82 -0
  12. xpcsjax/cli/plot_dispatch.py +181 -0
  13. xpcsjax/cli/plot_families/__init__.py +1 -0
  14. xpcsjax/cli/plot_families/experimental.py +69 -0
  15. xpcsjax/cli/plot_families/postfit.py +176 -0
  16. xpcsjax/cli/plot_families/simulated.py +276 -0
  17. xpcsjax/cli/result_saving.py +19 -0
  18. xpcsjax/cli/xla_config.py +192 -0
  19. xpcsjax/config/__init__.py +94 -0
  20. xpcsjax/config/heterodyne_parameter_manager.py +676 -0
  21. xpcsjax/config/heterodyne_parameter_names.py +118 -0
  22. xpcsjax/config/heterodyne_parameter_space.py +561 -0
  23. xpcsjax/config/heterodyne_physics_validators.py +577 -0
  24. xpcsjax/config/manager.py +1428 -0
  25. xpcsjax/config/parameter_manager.py +824 -0
  26. xpcsjax/config/parameter_names.py +319 -0
  27. xpcsjax/config/parameter_registry.py +924 -0
  28. xpcsjax/config/parameter_space.py +479 -0
  29. xpcsjax/config/physics_validators.py +394 -0
  30. xpcsjax/config/templates/xpcsjax_laminar_flow.yaml +875 -0
  31. xpcsjax/config/templates/xpcsjax_static_anisotropic.yaml +1179 -0
  32. xpcsjax/config/templates/xpcsjax_static_isotropic.yaml +662 -0
  33. xpcsjax/config/templates/xpcsjax_two_component.yaml +982 -0
  34. xpcsjax/config/types.py +426 -0
  35. xpcsjax/core/__init__.py +21 -0
  36. xpcsjax/core/diagonal_correction.py +562 -0
  37. xpcsjax/core/fitting.py +129 -0
  38. xpcsjax/core/heterodyne_jax_backend.py +653 -0
  39. xpcsjax/core/heterodyne_model.py +327 -0
  40. xpcsjax/core/heterodyne_model_stateful.py +471 -0
  41. xpcsjax/core/heterodyne_models.py +489 -0
  42. xpcsjax/core/heterodyne_physics_factors.py +230 -0
  43. xpcsjax/core/heterodyne_physics_kernel.py +442 -0
  44. xpcsjax/core/heterodyne_physics_utils.py +419 -0
  45. xpcsjax/core/heterodyne_scaling_utils.py +761 -0
  46. xpcsjax/core/homodyne_model.py +379 -0
  47. xpcsjax/core/jax_backend.py +1846 -0
  48. xpcsjax/core/math_primitives.py +57 -0
  49. xpcsjax/core/model_mixins.py +519 -0
  50. xpcsjax/core/models.py +716 -0
  51. xpcsjax/core/physics.py +598 -0
  52. xpcsjax/core/physics_factors.py +369 -0
  53. xpcsjax/core/physics_nlsq.py +478 -0
  54. xpcsjax/core/physics_utils.py +356 -0
  55. xpcsjax/data/__init__.py +304 -0
  56. xpcsjax/data/angle_filtering.py +417 -0
  57. xpcsjax/data/config.py +780 -0
  58. xpcsjax/data/dataset.py +82 -0
  59. xpcsjax/data/filtering_utils.py +664 -0
  60. xpcsjax/data/memory_manager.py +1593 -0
  61. xpcsjax/data/optimization.py +1152 -0
  62. xpcsjax/data/performance_engine.py +1942 -0
  63. xpcsjax/data/phi_filtering.py +442 -0
  64. xpcsjax/data/preprocessing.py +1200 -0
  65. xpcsjax/data/quality_controller.py +2094 -0
  66. xpcsjax/data/types.py +48 -0
  67. xpcsjax/data/validation.py +1287 -0
  68. xpcsjax/data/validators.py +328 -0
  69. xpcsjax/data/xpcs_loader.py +2741 -0
  70. xpcsjax/device/__init__.py +276 -0
  71. xpcsjax/device/config.py +265 -0
  72. xpcsjax/device/cpu.py +588 -0
  73. xpcsjax/gui/__init__.py +8 -0
  74. xpcsjax/gui/app.py +109 -0
  75. xpcsjax/gui/controllers/__init__.py +3 -0
  76. xpcsjax/gui/controllers/fit_queue.py +252 -0
  77. xpcsjax/gui/data_inspect.py +117 -0
  78. xpcsjax/gui/error_presenter.py +47 -0
  79. xpcsjax/gui/export.py +80 -0
  80. xpcsjax/gui/ipc/__init__.py +3 -0
  81. xpcsjax/gui/ipc/diagnostics.py +59 -0
  82. xpcsjax/gui/ipc/emitter.py +53 -0
  83. xpcsjax/gui/ipc/handle.py +226 -0
  84. xpcsjax/gui/ipc/job.py +26 -0
  85. xpcsjax/gui/ipc/log_capture.py +52 -0
  86. xpcsjax/gui/ipc/worker.py +87 -0
  87. xpcsjax/gui/project/__init__.py +3 -0
  88. xpcsjax/gui/project/model.py +105 -0
  89. xpcsjax/gui/project/persist.py +111 -0
  90. xpcsjax/gui/project/tree_model.py +83 -0
  91. xpcsjax/gui/result_loader.py +91 -0
  92. xpcsjax/gui/theme.py +438 -0
  93. xpcsjax/gui/views/__init__.py +3 -0
  94. xpcsjax/gui/views/config_dialogs.py +302 -0
  95. xpcsjax/gui/views/error_dialog.py +41 -0
  96. xpcsjax/gui/views/inspector.py +122 -0
  97. xpcsjax/gui/views/main_window.py +592 -0
  98. xpcsjax/gui/views/main_window_support/__init__.py +1 -0
  99. xpcsjax/gui/views/main_window_support/project_dialog_handler.py +148 -0
  100. xpcsjax/gui/views/main_window_support/result_presenter.py +110 -0
  101. xpcsjax/gui/views/main_window_support/run_controller.py +111 -0
  102. xpcsjax/gui/views/main_window_support/status_manager.py +54 -0
  103. xpcsjax/gui/views/plots/__init__.py +1 -0
  104. xpcsjax/gui/views/plots/grid.py +245 -0
  105. xpcsjax/gui/views/plots/helpers.py +92 -0
  106. xpcsjax/gui/views/plots/maps.py +157 -0
  107. xpcsjax/gui/views/plots/residuals.py +102 -0
  108. xpcsjax/gui/views/plots/squares.py +58 -0
  109. xpcsjax/gui/views/plots_view.py +31 -0
  110. xpcsjax/gui/views/project_panel.py +97 -0
  111. xpcsjax/gui/views/raster.py +45 -0
  112. xpcsjax/gui/viz_bundle.py +62 -0
  113. xpcsjax/io/__init__.py +20 -0
  114. xpcsjax/io/json_utils.py +144 -0
  115. xpcsjax/io/nlsq_writers.py +232 -0
  116. xpcsjax/optimization/__init__.py +129 -0
  117. xpcsjax/optimization/batch_statistics.py +186 -0
  118. xpcsjax/optimization/exceptions.py +262 -0
  119. xpcsjax/optimization/nlsq/__init__.py +1178 -0
  120. xpcsjax/optimization/nlsq/adapter.py +1439 -0
  121. xpcsjax/optimization/nlsq/adapter_base.py +364 -0
  122. xpcsjax/optimization/nlsq/adaptive_regularization.py +527 -0
  123. xpcsjax/optimization/nlsq/anti_degeneracy_controller.py +1054 -0
  124. xpcsjax/optimization/nlsq/anti_degeneracy_diagnostics.py +62 -0
  125. xpcsjax/optimization/nlsq/anti_degeneracy_logging.py +117 -0
  126. xpcsjax/optimization/nlsq/cmaes_wrapper.py +1246 -0
  127. xpcsjax/optimization/nlsq/config.py +1179 -0
  128. xpcsjax/optimization/nlsq/core.py +2684 -0
  129. xpcsjax/optimization/nlsq/data_prep.py +377 -0
  130. xpcsjax/optimization/nlsq/fallback_chain.py +455 -0
  131. xpcsjax/optimization/nlsq/fit_computation.py +539 -0
  132. xpcsjax/optimization/nlsq/gradient_diagnostics.py +267 -0
  133. xpcsjax/optimization/nlsq/gradient_monitor.py +634 -0
  134. xpcsjax/optimization/nlsq/heterodyne_adapter.py +935 -0
  135. xpcsjax/optimization/nlsq/heterodyne_adapter_base.py +80 -0
  136. xpcsjax/optimization/nlsq/heterodyne_config.py +1314 -0
  137. xpcsjax/optimization/nlsq/heterodyne_constant_mode.py +518 -0
  138. xpcsjax/optimization/nlsq/heterodyne_core.py +4606 -0
  139. xpcsjax/optimization/nlsq/heterodyne_data_prep.py +349 -0
  140. xpcsjax/optimization/nlsq/heterodyne_engine_route.py +745 -0
  141. xpcsjax/optimization/nlsq/heterodyne_logging.py +347 -0
  142. xpcsjax/optimization/nlsq/heterodyne_memory.py +340 -0
  143. xpcsjax/optimization/nlsq/heterodyne_multistart.py +142 -0
  144. xpcsjax/optimization/nlsq/heterodyne_result_builder.py +747 -0
  145. xpcsjax/optimization/nlsq/heterodyne_results.py +217 -0
  146. xpcsjax/optimization/nlsq/heterodyne_stratified_data.py +229 -0
  147. xpcsjax/optimization/nlsq/heterodyne_stratified_ls.py +1292 -0
  148. xpcsjax/optimization/nlsq/heterodyne_views.py +109 -0
  149. xpcsjax/optimization/nlsq/hierarchical.py +733 -0
  150. xpcsjax/optimization/nlsq/jacobian.py +244 -0
  151. xpcsjax/optimization/nlsq/memory.py +516 -0
  152. xpcsjax/optimization/nlsq/model_adapter.py +231 -0
  153. xpcsjax/optimization/nlsq/multistart.py +1482 -0
  154. xpcsjax/optimization/nlsq/parallel_accumulator.py +773 -0
  155. xpcsjax/optimization/nlsq/parameter_index_mapper.py +318 -0
  156. xpcsjax/optimization/nlsq/parameter_utils.py +587 -0
  157. xpcsjax/optimization/nlsq/per_angle_mode.py +392 -0
  158. xpcsjax/optimization/nlsq/progress.py +534 -0
  159. xpcsjax/optimization/nlsq/recovery.py +529 -0
  160. xpcsjax/optimization/nlsq/result_builder.py +453 -0
  161. xpcsjax/optimization/nlsq/results.py +326 -0
  162. xpcsjax/optimization/nlsq/shear_weighting.py +466 -0
  163. xpcsjax/optimization/nlsq/strategies/__init__.py +65 -0
  164. xpcsjax/optimization/nlsq/strategies/chunking.py +1233 -0
  165. xpcsjax/optimization/nlsq/strategies/executors.py +489 -0
  166. xpcsjax/optimization/nlsq/strategies/heterodyne_hybrid_streaming.py +995 -0
  167. xpcsjax/optimization/nlsq/strategies/hybrid_streaming.py +2039 -0
  168. xpcsjax/optimization/nlsq/strategies/out_of_core.py +604 -0
  169. xpcsjax/optimization/nlsq/strategies/residual.py +848 -0
  170. xpcsjax/optimization/nlsq/strategies/residual_jit.py +644 -0
  171. xpcsjax/optimization/nlsq/strategies/sequential.py +1010 -0
  172. xpcsjax/optimization/nlsq/strategies/stratified_ls.py +1015 -0
  173. xpcsjax/optimization/nlsq/transforms.py +491 -0
  174. xpcsjax/optimization/nlsq/validation.py +777 -0
  175. xpcsjax/optimization/nlsq/wrapper.py +4174 -0
  176. xpcsjax/optimization/numerical_validation.py +251 -0
  177. xpcsjax/optimization/recovery_strategies.py +209 -0
  178. xpcsjax/post_install.py +952 -0
  179. xpcsjax/runtime/__init__.py +34 -0
  180. xpcsjax/runtime/shell/__init__.py +49 -0
  181. xpcsjax/runtime/shell/activation/__init__.py +1 -0
  182. xpcsjax/runtime/shell/activation/xla_config.bash +121 -0
  183. xpcsjax/runtime/shell/activation/xla_config.fish +116 -0
  184. xpcsjax/runtime/shell/completion.sh +228 -0
  185. xpcsjax/runtime/shell/completion_spec.py +99 -0
  186. xpcsjax/runtime/shell/generate_completion.py +246 -0
  187. xpcsjax/runtime/utils/__init__.py +17 -0
  188. xpcsjax/runtime/utils/system_validator.py +702 -0
  189. xpcsjax/service/__init__.py +20 -0
  190. xpcsjax/service/config.py +209 -0
  191. xpcsjax/service/data.py +226 -0
  192. xpcsjax/service/events.py +107 -0
  193. xpcsjax/service/fit.py +196 -0
  194. xpcsjax/service/persist.py +412 -0
  195. xpcsjax/service/plots.py +100 -0
  196. xpcsjax/uninstall_scripts.py +634 -0
  197. xpcsjax/utils/__init__.py +34 -0
  198. xpcsjax/utils/async_io.py +337 -0
  199. xpcsjax/utils/logging.py +1755 -0
  200. xpcsjax/utils/path_validation.py +307 -0
  201. xpcsjax/viz/__init__.py +62 -0
  202. xpcsjax/viz/datashader_backend.py +374 -0
  203. xpcsjax/viz/diagnostics.py +139 -0
  204. xpcsjax/viz/nlsq_plots.py +1827 -0
  205. xpcsjax-0.1.0.data/data/share/xpcsjax/templates/xpcsjax_laminar_flow.yaml +875 -0
  206. xpcsjax-0.1.0.data/data/share/xpcsjax/templates/xpcsjax_static_anisotropic.yaml +1179 -0
  207. xpcsjax-0.1.0.data/data/share/xpcsjax/templates/xpcsjax_static_isotropic.yaml +662 -0
  208. xpcsjax-0.1.0.data/data/share/xpcsjax/templates/xpcsjax_two_component.yaml +982 -0
  209. xpcsjax-0.1.0.dist-info/METADATA +373 -0
  210. xpcsjax-0.1.0.dist-info/RECORD +213 -0
  211. xpcsjax-0.1.0.dist-info/WHEEL +4 -0
  212. xpcsjax-0.1.0.dist-info/entry_points.txt +17 -0
  213. xpcsjax-0.1.0.dist-info/licenses/LICENSE +21 -0
xpcsjax/__init__.py ADDED
@@ -0,0 +1,166 @@
1
+ """xpcsjax — unified JAX-native XPCS NLSQ fitting.
2
+
3
+ Public API (lazy-loaded — heavy deps like JAX import on first use):
4
+
5
+ from xpcsjax import load_xpcs_data, fit_nlsq, ConfigManager
6
+
7
+ data = load_xpcs_data("config.yaml")
8
+ result = fit_nlsq(data, "config.yaml")
9
+ print(result.parameters)
10
+ result.save("output/")
11
+
12
+ Env setup at import time is mirrored from homodyne/__init__.py, with xpcsjax-specific concurrency gating (see _xla_host_device_count).
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ # ============================================================================
18
+ # Standard library imports
19
+ # ============================================================================
20
+ import importlib
21
+ import logging
22
+ import os
23
+
24
+ # ============================================================================
25
+ # JAX CPU Device Configuration (MUST be set before JAX import)
26
+ # ============================================================================
27
+ # Mirrored from homodyne/__init__.py, with one xpcsjax-specific adaptation
28
+ # (concurrency gating, see _xla_host_device_count):
29
+ # - xla_force_host_platform_device_count: enables parallel evaluation paths
30
+ # - xla_disable_hlo_passes=constant_folding: prevents > 1 s slow-compilation
31
+ # warnings on HYBRID_STREAMING strategy (23M+ points) where data arrays
32
+ # are captured in JIT closures. Performance impact: minimal (< 5 ms/call).
33
+
34
+
35
+ def _detect_worker_count() -> int:
36
+ """Concurrent fit-process count from the pool / pytest-xdist env-vars (>=1).
37
+
38
+ Inlined here (not imported from ``optimization.nlsq.memory``) on purpose:
39
+ this runs *before* the first JAX import, and importing that module could pull
40
+ JAX in early and defeat the env-before-import ordering this block exists for.
41
+ """
42
+ for _env_var in ("XPCSJAX_FIT_CONCURRENCY", "PYTEST_XDIST_WORKER_COUNT"):
43
+ _raw = os.environ.get(_env_var)
44
+ if _raw:
45
+ try:
46
+ return max(1, int(_raw))
47
+ except ValueError:
48
+ pass
49
+ return 1
50
+
51
+
52
+ def _xla_host_device_count(worker_count: int) -> int:
53
+ """Return the number of host CPU devices to force for XLA.
54
+
55
+ A lone fit benefits from 4 host devices (parallel evaluation paths). But
56
+ under parallelism (pytest-xdist or the production multistart / accumulator
57
+ pools) each of N worker processes forcing 4 devices means 4*N logical devices
58
+ contending for the physical cores — wasted per-worker compile/buffer overhead
59
+ and RAM. So drop to a single device per process whenever more than one fit
60
+ runs concurrently.
61
+ """
62
+ return 4 if worker_count <= 1 else 1
63
+
64
+
65
def _merged_xla_flags(existing: str | None, defaults: list[str]) -> str:
    """Return the XLA_FLAGS value to export, given the current value (or None).

    With no existing value, the defaults are joined with spaces. Otherwise
    each default flag is appended unless its ``--name`` part (text before
    ``=``) already occurs in ``existing``, so an explicit user/operator
    setting of the same flag always wins. This lives in a function so its
    temporaries stay out of the package namespace.

    NOTE(review): the presence test is a plain substring match, so a longer
    flag name that merely contains a default's name would also suppress it.
    """
    if existing is None:
        return " ".join(defaults)
    missing = [flag for flag in defaults if flag.split("=")[0] not in existing]
    if not missing:
        return existing
    return f"{existing} {' '.join(missing)}"


_WORKER_COUNT = _detect_worker_count()
_DEFAULT_XLA_FLAGS = [
    f"--xla_force_host_platform_device_count={_xla_host_device_count(_WORKER_COUNT)}",
    "--xla_disable_hlo_passes=constant_folding",
]

# JAX must be in float64 for parameters spanning 6+ orders of magnitude.
# This env var MUST be set BEFORE the first JAX import.
os.environ.setdefault("JAX_ENABLE_X64", "1")

# Merge our defaults into any XLA_FLAGS already exported; a flag the caller
# set explicitly is left untouched.
os.environ["XLA_FLAGS"] = _merged_xla_flags(os.environ.get("XLA_FLAGS"), _DEFAULT_XLA_FLAGS)

# Pin JAX to CPU (CPU-only; no GPU support). Setting this at
# package import time (before any jax import) is the *only* place this works
# reliably — spawn-pool worker init runs *after* the worker's `import jax`,
# which is why xpcsjax.viz.nlsq_plots.{_worker_init_cpu_only,_render_one_angle_worker}
# can't set this themselves. Child processes inherit os.environ from the parent.
# setdefault: an explicitly exported value always takes precedence.
os.environ.setdefault("JAX_PLATFORMS", "cpu")
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
os.environ.setdefault("XLA_PYTHON_CLIENT_ALLOCATOR", "platform")

# Cap each worker's JAX arena under parallelism so one fit process can't claim
# the whole device. No-op on the CPU backend (the v0.1 target), but a correct
# per-worker bound for any future GPU backend; setdefault leaves an explicit
# user/operator override untouched. The share is 0.9 / N, floored at 0.05.
if _WORKER_COUNT > 1:
    os.environ.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", f"{max(0.05, 0.9 / _WORKER_COUNT):.4f}")

# Suppress NLSQ GPU warnings (CPU-only; no GPU support)
os.environ.setdefault("NLSQ_SKIP_GPU_CHECK", "1")

# Suppress JAX backend logs (GPU fallback warnings on CPU-only systems)
logging.getLogger("jax._src.xla_bridge").setLevel(logging.ERROR)
logging.getLogger("jax._src.compiler").setLevel(logging.ERROR)
109
+
110
# ============================================================================
# Version
# ============================================================================
# NOTE(review): hard-coded; keep in sync with the distribution metadata
# version (the wheel is built as 0.1.0).
__version__ = "0.1.0"

# ============================================================================
# Lazy public API
# ============================================================================
# Public name -> dotted path of the submodule that defines it. Nothing listed
# here is imported when the package is imported; ``__getattr__`` imports the
# owning submodule on first attribute access, which keeps heavy dependencies
# such as JAX out of a bare ``import xpcsjax``.
_LAZY_EXPORTS: dict[str, str] = {
    "load_xpcs_data": "xpcsjax.data",
    "fit_nlsq": "xpcsjax.optimization.nlsq",
    "ConfigManager": "xpcsjax.config",
    "generate_nlsq_plots": "xpcsjax.viz",
    "HomodyneModel": "xpcsjax.core",
    "HeterodyneModel": "xpcsjax.core",
    "OptimizationResult": "xpcsjax.optimization.nlsq.results",
}

# Static-analysis-only imports: Pyright / mypy / IDEs resolve the real
# symbols from these lines, while at runtime ``_TYPE_CHECKING`` is False and
# nothing below is imported (resolution stays lazy via ``__getattr__``).
# Keep this list in step with ``_LAZY_EXPORTS``.
from typing import TYPE_CHECKING as _TYPE_CHECKING

if _TYPE_CHECKING:
    from xpcsjax.config import ConfigManager
    from xpcsjax.core import HeterodyneModel, HomodyneModel
    from xpcsjax.data import load_xpcs_data
    from xpcsjax.optimization.nlsq import fit_nlsq
    from xpcsjax.optimization.nlsq.results import OptimizationResult
    from xpcsjax.viz import generate_nlsq_plots
140
+
141
+
142
def __getattr__(name: str):  # noqa: D401
    """Import a documented public symbol on first access (PEP 562 hook).

    The owning submodule is taken from ``_LAZY_EXPORTS``. The resolved object
    is stored in the module globals, so later lookups never reach this hook.

    Raises:
        AttributeError: If ``name`` is not one of the lazily exported names.
    """
    target_module = _LAZY_EXPORTS.get(name)
    if target_module is None:
        raise AttributeError(f"module 'xpcsjax' has no attribute {name!r}")
    resolved = getattr(importlib.import_module(target_module), name)
    # Cache as a real module attribute so this import happens only once.
    globals()[name] = resolved
    return resolved
150
+
151
+
152
# Literal __all__ (not derived from _LAZY_EXPORTS) so Pyright's
# reportUnsupportedDunderAll check can read it statically. Every name is
# resolved lazily, so ``from xpcsjax import *`` goes through __getattr__ and
# imports each listed submodule.
__all__ = [
    "load_xpcs_data",
    "fit_nlsq",
    "ConfigManager",
    "generate_nlsq_plots",
    "HomodyneModel",
    "HeterodyneModel",
    "OptimizationResult",
]

# Import-time guard that __all__ and _LAZY_EXPORTS name the same symbols.
# An explicit raise rather than ``assert``: assert statements are stripped
# under ``python -O`` / PYTHONOPTIMIZE, which would silently disable the check.
if set(__all__) != set(_LAZY_EXPORTS):
    raise AssertionError(
        "xpcsjax public API mismatch between __all__ and _LAZY_EXPORTS"
    )
@@ -0,0 +1,78 @@
1
+ """Command-line interface for xpcsjax (NLSQ-only XPCS analysis).
2
+
3
+ This subpackage is lazy-loaded: importing :mod:`xpcsjax.cli` does NOT
4
+ import JAX or any of the heavy submodules. Attribute access (e.g.
5
+ ``xpcsjax.cli.main``) triggers the actual import via ``__getattr__``.
6
+
7
+ Mirrors heterodyne's CLI lazy-import surface so test mocks can target
8
+ the same import paths.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from typing import TYPE_CHECKING, Any
14
+
15
+ # TYPE_CHECKING block: gives Pyright / mypy / IDEs static visibility of the
16
+ # lazy-exported symbols without paying the import cost at runtime. The
17
+ # ``__getattr__`` below is what actually resolves them at first access.
18
+ if TYPE_CHECKING:
19
+ from xpcsjax.cli.args_parser import create_parser, validate_args
20
+ from xpcsjax.cli.commands import dispatch_command
21
+ from xpcsjax.cli.config_generator import main as config_main
22
+ from xpcsjax.cli.config_handling import apply_cli_overrides, load_and_merge_config
23
+ from xpcsjax.cli.data_pipeline import load_and_validate_data, resolve_phi_angles
24
+ from xpcsjax.cli.main import main
25
+ from xpcsjax.cli.optimization_runner import run_nlsq
26
+ from xpcsjax.cli.plot_dispatch import dispatch_plots
27
+ from xpcsjax.cli.xla_config import configure_xla
28
+
29
+
30
# Lazy export table: public attribute name -> (submodule path, attribute name
# inside that submodule). ``__getattr__`` consults it on first access. Note
# that ``config_main`` is re-exported under a different name than its source
# attribute (``xpcsjax.cli.config_generator.main``).
_IMPORTS: dict[str, tuple[str, str]] = {
    "main": ("xpcsjax.cli.main", "main"),
    "config_main": ("xpcsjax.cli.config_generator", "main"),
    "configure_xla": ("xpcsjax.cli.xla_config", "configure_xla"),
    "create_parser": ("xpcsjax.cli.args_parser", "create_parser"),
    "validate_args": ("xpcsjax.cli.args_parser", "validate_args"),
    "dispatch_command": ("xpcsjax.cli.commands", "dispatch_command"),
    "load_and_merge_config": (
        "xpcsjax.cli.config_handling",
        "load_and_merge_config",
    ),
    "apply_cli_overrides": (
        "xpcsjax.cli.config_handling",
        "apply_cli_overrides",
    ),
    "load_and_validate_data": (
        "xpcsjax.cli.data_pipeline",
        "load_and_validate_data",
    ),
    "resolve_phi_angles": ("xpcsjax.cli.data_pipeline", "resolve_phi_angles"),
    "run_nlsq": ("xpcsjax.cli.optimization_runner", "run_nlsq"),
    "dispatch_plots": ("xpcsjax.cli.plot_dispatch", "dispatch_plots"),
}
53
+
54
+
55
def __getattr__(name: str) -> Any:
    """Resolve a lazily exported CLI symbol on attribute access (PEP 562).

    Looks ``name`` up in ``_IMPORTS``, imports the owning submodule and
    returns the mapped attribute. The result is not cached in this module's
    globals, so every access that reaches this hook repeats the import call
    (cheap once the submodule is in ``sys.modules``).

    NOTE(review): ``main`` names both a lazy export (the CLI entry function)
    and the submodule ``xpcsjax.cli.main``. Importing that submodule makes the
    import system bind the *module* as this package's ``main`` attribute.
    After that, plain attribute lookup returns the module without reaching
    this hook. Confirm callers always get the object they expect.

    Raises:
        AttributeError: If ``name`` is not a lazy export.
    """
    target = _IMPORTS.get(name)
    if target is None:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    module_path, attr = target
    # Imported here, not at module top, to keep ``import xpcsjax.cli`` minimal.
    import importlib

    return getattr(importlib.import_module(module_path), attr)
63
+
64
+
65
# Public, lazily resolved CLI names (same names, same order as _IMPORTS).
# Kept as a literal list so static analyzers can read it. Because every name
# is lazy, ``from xpcsjax.cli import *`` resolves each one through
# ``__getattr__`` and thereby imports all of the listed submodules.
__all__ = [
    "main",
    "config_main",
    "configure_xla",
    "create_parser",
    "validate_args",
    "dispatch_command",
    "load_and_merge_config",
    "apply_cli_overrides",
    "load_and_validate_data",
    "resolve_phi_angles",
    "run_nlsq",
    "dispatch_plots",
]
@@ -0,0 +1,442 @@
1
+ """Argument parser for the xpcsjax CLI.
2
+
3
+ NLSQ-only by design (see project CLAUDE.md); Bayesian sampling flags from
4
+ the upstream heterodyne CLI are intentionally absent.
5
+
6
+ Parameter override flags map to canonical names per
7
+ ``parameter_registry._MODE_PARAMS``:
8
+
9
+ * ``static_anisotropic`` / ``static_isotropic`` — D0, alpha, D_offset
10
+ * ``laminar_flow`` — D0, alpha, D_offset, gamma_dot_t0, beta,
11
+ gamma_dot_t_offset, phi0
12
+ * ``two_component`` — D0_ref/alpha_ref/D_offset_ref,
13
+ D0_sample/alpha_sample/D_offset_sample, v0, v_beta, v_offset,
14
+ f0..f3, phi0_het. The reference/sample transport params and phi0_het
15
+ have no CLI flag (14 params is too many for the command line); set
16
+ them in the YAML config. The v0/v_beta/v_offset/f0..f3 flags below DO
17
+ apply to two_component.
18
+
19
+ Flags whose canonical name is not in the active mode's parameter set are
20
+ silently ignored by ``config_handling.apply_cli_overrides`` (it
21
+ intersects against ``ConfigManager.get_active_parameters()`` before
22
+ writing the canonical ``initial_parameters`` block).
23
+ """
24
+
25
+ from __future__ import annotations
26
+
27
+ import argparse
28
+ from pathlib import Path
29
+
30
_VALID_MODES = ("static_anisotropic", "static_isotropic", "laminar_flow", "two_component")

# Single-value initial-parameter override flags, in the order they appear in
# ``--help``, as (flag, help text). Every flag is an optional float shown with
# metavar ``VAL``; ``None`` means "not given, keep the YAML value". argparse
# derives each dest from the flag, e.g. ``--initial-D-offset`` ->
# ``args.initial_D_offset``. Flags unused by the active mode are ignored
# downstream (see the module docstring).
_PARAM_OVERRIDE_FLAGS: tuple[tuple[str, str], ...] = (
    # Core transport (all modes)
    ("--initial-D0", "Diffusion prefactor D0 [Å²/s^α]."),
    ("--initial-alpha", "Transport exponent alpha."),
    ("--initial-D-offset", "Transport offset D_offset [Å²/s]."),
    # Laminar flow / two-component velocity
    ("--initial-gamma-dot-t0", "Shear rate prefactor (laminar_flow)."),
    ("--initial-gamma-dot-t-offset", "Shear rate offset (laminar_flow)."),
    ("--initial-beta", "Velocity exponent beta (laminar_flow)."),
    ("--initial-v-beta", "Velocity exponent v_beta (two_component)."),
    ("--initial-v0", "Velocity prefactor v0 (two_component)."),
    ("--initial-v-offset", "Velocity offset v_offset (two_component)."),
    # Angle parameters
    ("--initial-phi0", "Flow angle offset phi0 [degrees]."),
    # Two-component Fourier amplitudes (per-angle fraction)
    ("--initial-f0", "Sample fraction amplitude f0 (two_component)."),
    ("--initial-f1", "Fourier coefficient f1 (two_component)."),
    ("--initial-f2", "Fourier coefficient f2 (two_component)."),
    ("--initial-f3", "Fourier coefficient f3 (two_component)."),
)


def create_parser() -> argparse.ArgumentParser:
    """Build the xpcsjax CLI argument parser.

    ``None`` defaults are used wherever the override layer must tell
    "flag not given" apart from an explicit value. A YAML setting is then
    replaced only when the user actually passed the flag.

    Returns:
        A configured :class:`argparse.ArgumentParser`; ``--config`` is the
        only required argument.
    """
    parser = argparse.ArgumentParser(
        prog="xpcsjax",
        description=(
            "xpcsjax — JAX-native NLSQ fitting for homodyne / heterodyne XPCS. "
            "Bayesian sampling is permanently out of scope for this package; "
            "use the upstream homodyne / heterodyne packages for that."
        ),
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Run NLSQ fit with a YAML config
  xpcsjax --config analysis.yaml

  # Override the output directory
  xpcsjax --config analysis.yaml --output ./results

  # Multistart NLSQ with 16 restarts
  xpcsjax --config analysis.yaml --multistart --multistart-n 16

  # Plot experimental data only (skip optimization)
  xpcsjax --config analysis.yaml --plot-experimental-data

  # Plot simulated C2 heatmaps from config parameters
  xpcsjax --config analysis.yaml --plot-simulated-data --phi-angles 0,45,90,135

  # Launch the interactive analysis workbench (GUI) — separate console script
  xpcsjax-gui

Exit codes:
  0    Analysis completed and the optimizer converged (or no convergence
       check applies, e.g. plot-only runs).
  1    Unhandled exception during analysis. See log for traceback.
  2    Analysis ran but the optimizer did NOT converge. Output files are
       still written but the fit is not trustworthy.
  130  Interrupted by the user (Ctrl-C).
""",
    )

    # ------------------------------------------------------------------
    # Required: config path
    # ------------------------------------------------------------------
    parser.add_argument(
        "--config",
        "-c",
        type=Path,
        required=True,
        help="Path to YAML configuration file (required).",
    )

    # ------------------------------------------------------------------
    # Output
    # ------------------------------------------------------------------
    parser.add_argument(
        "--output",
        "-o",
        type=Path,
        default=None,
        help="Output directory (overrides ``output.directory`` in YAML).",
    )
    parser.add_argument(
        "--output-format",
        choices=["json", "npz", "both"],
        default="both",
        help="Format for saved results (default: both).",
    )

    # ------------------------------------------------------------------
    # Mode / phi
    # ------------------------------------------------------------------
    parser.add_argument(
        "--mode",
        choices=_VALID_MODES,
        default=None,
        help=(f"Force ``analysis_mode`` (overrides YAML). Must be one of {_VALID_MODES}."),
    )
    parser.add_argument(
        "--phi",
        type=float,
        nargs="+",
        default=None,
        help="Phi angles to analyze, in degrees (overrides config).",
    )

    # ------------------------------------------------------------------
    # NLSQ options
    # ------------------------------------------------------------------
    nlsq_group = parser.add_argument_group("NLSQ options", "Solver and multistart controls.")
    # store_const with default=None (not store_true) so the override layer
    # can distinguish "user did not pass --multistart" (None → leave YAML
    # untouched) from "user passed it" (True). store_true's False default
    # would clobber a YAML ``multi_start.enable: true`` on every run.
    nlsq_group.add_argument(
        "--multistart",
        action="store_const",
        const=True,
        default=None,
        help="Enable LHS multistart for NLSQ (overrides YAML).",
    )
    nlsq_group.add_argument(
        "--no-multistart",
        dest="multistart",
        action="store_const",
        const=False,
        help="Explicitly disable multistart (overrides a YAML enable).",
    )
    nlsq_group.add_argument(
        "--multistart-n",
        type=int,
        default=None,
        help="Number of multistart restarts (default: from config, else 10).",
    )
    nlsq_group.add_argument(
        "--max-iterations",
        type=int,
        default=None,
        help="Maximum trust-region iterations (overrides config).",
    )
    nlsq_group.add_argument(
        "--tolerance",
        type=float,
        default=None,
        help="NLSQ convergence tolerance (overrides config).",
    )

    # ------------------------------------------------------------------
    # Parameter overrides (highest precedence)
    # CLI > YAML > parameter_registry defaults.
    # Covers the union of all four modes; mode-incompatible flags are
    # silently ignored downstream.
    # ------------------------------------------------------------------
    param_group = parser.add_argument_group(
        "parameter overrides",
        "Override initial parameter values (highest precedence). "
        "Flags for parameters not used by the active mode are ignored.",
    )
    for flag, help_text in _PARAM_OVERRIDE_FLAGS:
        param_group.add_argument(
            flag,
            type=float,
            default=None,
            metavar="VAL",
            help=help_text,
        )
    # Note: per-angle scaling (contrast/offset) is not a single-value CLI
    # override — there is one pair per phi angle. Set scaling in the YAML
    # config's per-angle scaling block instead.

    # ------------------------------------------------------------------
    # Verbosity
    # ------------------------------------------------------------------
    parser.add_argument(
        "--verbose",
        "-v",
        action="count",
        default=0,
        help="Increase verbosity (-v, -vv).",
    )
    parser.add_argument(
        "--quiet",
        "-q",
        action="store_true",
        help="Suppress all output except errors.",
    )

    # ------------------------------------------------------------------
    # Performance
    # ------------------------------------------------------------------
    parser.add_argument(
        "--threads",
        type=int,
        default=None,
        help="CPU thread count for XLA (default: auto).",
    )
    parser.add_argument(
        "--no-jit",
        action="store_true",
        help="Disable JIT compilation (for debugging only — much slower).",
    )

    # ------------------------------------------------------------------
    # Plotting
    # ------------------------------------------------------------------
    # --plot only restates the default (True); --no-plot is the real switch.
    plot_group = parser.add_mutually_exclusive_group()
    plot_group.add_argument(
        "--plot",
        dest="plot",
        action="store_true",
        default=True,
        help="Generate plots after fitting (default).",
    )
    plot_group.add_argument(
        "--no-plot",
        dest="plot",
        action="store_false",
        help="Skip plot generation.",
    )
    parser.add_argument(
        "--save-plots",
        action="store_true",
        help="Save fit-comparison plots to the output directory.",
    )
    parser.add_argument(
        "--plotting-backend",
        choices=["auto", "matplotlib", "datashader"],
        default="auto",
        help=(
            "Plotting backend: auto (Datashader if installed), "
            "matplotlib, or datashader (default: %(default)s)."
        ),
    )
    parser.add_argument(
        "--parallel-plots",
        action="store_true",
        help="Generate plots in parallel via multiprocessing (Datashader path).",
    )
    parser.add_argument(
        "--phi-angles",
        type=str,
        default=None,
        help=("Comma-separated phi angles in degrees for simulated data (e.g. '0,45,90,135')."),
    )

    # Standalone plot modes (skip optimization)
    parser.add_argument(
        "--plot-experimental-data",
        action="store_true",
        help="Plot experimental data for QC (skip optimization).",
    )
    parser.add_argument(
        "--plot-simulated-data",
        action="store_true",
        help="Plot simulated C2 heatmaps from config parameters (skip optimization).",
    )
    parser.add_argument(
        "--contrast",
        type=float,
        default=0.3,
        help="Contrast for simulated data (default: %(default)s; requires --plot-simulated-data).",
    )
    parser.add_argument(
        "--offset-sim",
        type=float,
        default=1.0,
        help="Offset for simulated data (default: %(default)s; requires --plot-simulated-data).",
    )

    # ------------------------------------------------------------------
    # Version
    # ------------------------------------------------------------------
    _add_version_arg(parser)

    return parser


def _add_version_arg(parser: argparse.ArgumentParser) -> None:
    """Add ``--version`` to *parser*, resolving the version best-effort.

    Tries the installed distribution metadata first, then the package's
    ``__version__`` (e.g. an uninstalled source tree), and finally falls back
    to the literal ``"unknown"``. Lookup failures are deliberately swallowed:
    a missing version must never stop the CLI from building its parser.
    """
    try:
        import importlib.metadata as _md

        version = _md.version("xpcsjax")
    except Exception:  # pragma: no cover — uninstalled / dev tree
        try:
            from xpcsjax import __version__ as version
        except Exception:
            version = "unknown"
    parser.add_argument(
        "--version",
        action="version",
        version=f"%(prog)s {version}",
    )
399
+
400
+
401
def validate_args(args: argparse.Namespace) -> list[str]:
    """Light validation pass over parsed CLI arguments.

    Recoverable problems are repaired in place on ``args``. ``--verbose`` is
    reset when ``--quiet`` is also given. Non-positive ``--multistart-n`` /
    ``--max-iterations`` fall back to the config value, and a non-positive
    ``--threads`` falls back to auto. A warning string describing each repair
    is returned.

    Optional attributes (``phi_angles``, ``max_iterations``, ``threads``) are
    read with ``getattr`` so hand-built namespaces lacking them still validate.

    Args:
        args: Namespace produced by :func:`create_parser`.

    Returns:
        Non-fatal warning messages; empty when nothing needed attention.

    Raises:
        FileNotFoundError: If ``args.config`` does not exist, or names a
            directory instead of a configuration file.
    """
    warnings: list[str] = []

    config_path = args.config
    if not config_path.exists():
        raise FileNotFoundError(f"Configuration file not found: {config_path}")
    if config_path.is_dir():
        # exists() alone accepts directories, which would only fail later
        # with a far less helpful error from the YAML loader.
        raise FileNotFoundError(
            f"Configuration path is a directory, not a file: {config_path}"
        )

    if args.verbose > 0 and args.quiet:
        warnings.append("Both --verbose and --quiet specified; --quiet wins.")
        args.verbose = 0

    phi_angles_str: str | None = getattr(args, "phi_angles", None)
    if phi_angles_str is not None:
        try:
            [float(x.strip()) for x in phi_angles_str.split(",")]
        except ValueError:
            warnings.append(
                f"--phi-angles must be comma-separated numbers "
                f"(e.g. '0,45,90,135'); got: {phi_angles_str!r}"
            )

    if args.plot_experimental_data and args.plot_simulated_data:
        warnings.append(
            "Both --plot-experimental-data and --plot-simulated-data given; running both passes."
        )

    if args.multistart_n is not None and args.multistart_n <= 0:
        warnings.append(
            f"--multistart-n must be positive; got {args.multistart_n}. "
            "Falling back to config value."
        )
        args.multistart_n = None

    # Same repair policy as --multistart-n: a non-positive iteration budget
    # can never produce a fit, so defer to the config value instead.
    max_iterations = getattr(args, "max_iterations", None)
    if max_iterations is not None and max_iterations <= 0:
        warnings.append(
            f"--max-iterations must be positive; got {max_iterations}. "
            "Falling back to config value."
        )
        args.max_iterations = None

    # None is the parser's "auto" sentinel for --threads.
    threads = getattr(args, "threads", None)
    if threads is not None and threads <= 0:
        warnings.append(f"--threads must be positive; got {threads}. Falling back to auto.")
        args.threads = None

    return warnings
438
+
439
+
440
# Alias exposing the parser factory under a uniform name; per the original
# note, completion_spec looks the parser up as ``build_parser``.
build_parser = create_parser  # uniform factory name used by completion_spec

# Public surface of this module: the two parser factories and the validator.
__all__ = ["build_parser", "create_parser", "validate_args"]