hwoutils 1.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,213 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[codz]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py.cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ #uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+ #poetry.toml
110
+
111
+ # pdm
112
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
113
+ # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
114
+ # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
115
+ #pdm.lock
116
+ #pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # pixi
121
+ # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
122
+ #pixi.lock
123
+ # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
124
+ # in the .venv directory. It is recommended not to include this directory in version control.
125
+ .pixi
126
+
127
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
128
+ __pypackages__/
129
+
130
+ # Celery stuff
131
+ celerybeat-schedule
132
+ celerybeat.pid
133
+
134
+ # SageMath parsed files
135
+ *.sage.py
136
+
137
+ # Environments
138
+ .env
139
+ .envrc
140
+ .venv
141
+ env/
142
+ venv/
143
+ ENV/
144
+ env.bak/
145
+ venv.bak/
146
+
147
+ # Spyder project settings
148
+ .spyderproject
149
+ .spyproject
150
+
151
+ # Rope project settings
152
+ .ropeproject
153
+
154
+ # mkdocs documentation
155
+ /site
156
+
157
+ # mypy
158
+ .mypy_cache/
159
+ .dmypy.json
160
+ dmypy.json
161
+
162
+ # Pyre type checker
163
+ .pyre/
164
+
165
+ # pytype static type analyzer
166
+ .pytype/
167
+
168
+ # Cython debug symbols
169
+ cython_debug/
170
+
171
+ # PyCharm
172
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
173
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
174
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
175
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
176
+ #.idea/
177
+
178
+ # Abstra
179
+ # Abstra is an AI-powered process automation framework.
180
+ # Ignore directories containing user credentials, local state, and settings.
181
+ # Learn more at https://abstra.io/docs
182
+ .abstra/
183
+
184
+ # Visual Studio Code
185
+ # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
186
+ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
187
+ # and can be added to the global gitignore or merged into this file. However, if you prefer,
188
+ # you could uncomment the following to ignore the entire vscode folder
189
+ # .vscode/
190
+
191
+ # Ruff stuff:
192
+ .ruff_cache/
193
+
194
+ # PyPI configuration file
195
+ .pypirc
196
+
197
+ # Cursor
198
+ # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
199
+ # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
200
+ # refer to https://docs.cursor.com/context/ignore-files
201
+ .cursorignore
202
+ .cursorindexingignore
203
+
204
+ # Marimo
205
+ marimo/_static/
206
+ marimo/_lsp/
207
+ __marimo__/
208
+
209
+ # Auto-generated version file (hatch-vcs)
210
+ *_version.py
211
+ .DS_Store
212
+ input/
213
+ output/
@@ -0,0 +1,20 @@
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: "v6.0.0"
4
+ hooks:
5
+ - id: trailing-whitespace
6
+ - id: name-tests-test
7
+ - id: end-of-file-fixer
8
+ - repo: https://github.com/astral-sh/ruff-pre-commit
9
+ rev: v0.14.10
10
+ hooks:
11
+ # Run the linter.
12
+ - id: ruff
13
+ # Run the formatter.
14
+ - id: ruff-format
15
+ - repo: https://github.com/compilerla/conventional-pre-commit
16
+ rev: v4.3.0
17
+ hooks:
18
+ - id: conventional-pre-commit
19
+ stages: [commit-msg]
20
+ args: []
@@ -0,0 +1,20 @@
1
+ # Required
2
+ version: 2
3
+
4
+ # Set the OS, Python version and other tools you might need
5
+ build:
6
+ os: ubuntu-22.04
7
+ tools:
8
+ python: "3.12"
9
+
10
+ python:
11
+ install:
12
+ - method: pip
13
+ path: .
14
+ extra_requirements:
15
+ # Install with the [docs] flag for sphinx docs specific extensions
16
+ - docs
17
+
18
+ # Build documentation in the "docs/" directory with Sphinx
19
+ sphinx:
20
+ configuration: docs/conf.py
@@ -0,0 +1,19 @@
1
+ # Changelog
2
+
3
+ All notable changes to this project will be documented in this file.
4
+
5
+ ## 1.0.0 (2026-02-21)
6
+
7
+
8
+ ### Features
9
+
10
+ * Initial migration of repeated code ([9150db9](https://github.com/CoreySpohn/hwoutils/commit/9150db9bac1368570d4177504377862668e5c164))
11
+
12
+ ## [Unreleased]
13
+
14
+ ### Added
15
+
16
+ - `constants` module with consolidated physical constants from orbix and coronagraphoto
17
+ - `conversions` module with JAX-native unit conversion functions
18
+ - `map_coordinates` module with cubic spline interpolation (from JAX PR #14218)
19
+ - `transforms` module with `resample_flux`, `ccw_rotation_matrix`, and `shift_image`
hwoutils-1.0.0/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Corey Spohn
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,113 @@
1
+ Metadata-Version: 2.4
2
+ Name: hwoutils
3
+ Version: 1.0.0
4
+ Summary: Shared JAX-based utilities for the HWO direct imaging simulation suite — constants, conversions, and image transforms.
5
+ Project-URL: Homepage, https://github.com/CoreySpohn/hwoutils
6
+ Project-URL: Issues, https://github.com/CoreySpohn/hwoutils/issues
7
+ Author-email: Corey Spohn <corey.a.spohn@nasa.gov>
8
+ License: MIT License
9
+
10
+ Copyright (c) 2026 Corey Spohn
11
+
12
+ Permission is hereby granted, free of charge, to any person obtaining a copy
13
+ of this software and associated documentation files (the "Software"), to deal
14
+ in the Software without restriction, including without limitation the rights
15
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
16
+ copies of the Software, and to permit persons to whom the Software is
17
+ furnished to do so, subject to the following conditions:
18
+
19
+ The above copyright notice and this permission notice shall be included in all
20
+ copies or substantial portions of the Software.
21
+
22
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
23
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
24
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
25
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
26
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
27
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
28
+ SOFTWARE.
29
+ License-File: LICENSE
30
+ Classifier: Development Status :: 3 - Alpha
31
+ Classifier: Intended Audience :: Science/Research
32
+ Classifier: License :: OSI Approved :: MIT License
33
+ Classifier: Programming Language :: Python :: 3
34
+ Classifier: Topic :: Scientific/Engineering :: Astronomy
35
+ Requires-Python: >=3.10
36
+ Requires-Dist: jax
37
+ Requires-Dist: jaxlib
38
+ Provides-Extra: dev
39
+ Requires-Dist: pre-commit; extra == 'dev'
40
+ Requires-Dist: ruff; extra == 'dev'
41
+ Provides-Extra: docs
42
+ Requires-Dist: ipython; extra == 'docs'
43
+ Requires-Dist: matplotlib; extra == 'docs'
44
+ Requires-Dist: myst-nb; extra == 'docs'
45
+ Requires-Dist: sphinx; extra == 'docs'
46
+ Requires-Dist: sphinx-autoapi; extra == 'docs'
47
+ Requires-Dist: sphinx-autodoc-typehints; extra == 'docs'
48
+ Requires-Dist: sphinx-book-theme; extra == 'docs'
49
+ Provides-Extra: test
50
+ Requires-Dist: hypothesis; extra == 'test'
51
+ Requires-Dist: nox; extra == 'test'
52
+ Requires-Dist: pytest; extra == 'test'
53
+ Requires-Dist: pytest-cov; extra == 'test'
54
+ Description-Content-Type: text/markdown
55
+
56
+ <p align="center">
57
+ <a href="https://pypi.org/project/hwoutils/"><img src="https://img.shields.io/pypi/v/hwoutils.svg?style=flat-square&logo=pypi" alt="PyPI"/></a>
58
+ <a href="https://hwoutils.readthedocs.io"><img src="https://readthedocs.org/projects/hwoutils/badge/?version=latest&style=flat-square" alt="Documentation Status"/></a>
59
+ <a href="https://github.com/CoreySpohn/hwoutils/actions/workflows/tests.yml"><img src="https://img.shields.io/github/actions/workflow/status/CoreySpohn/hwoutils/tests.yml?branch=main&style=flat-square&logo=github&label=tests" alt="Tests"/></a>
60
+ <a href="https://github.com/CoreySpohn/hwoutils/blob/main/LICENSE"><img src="https://img.shields.io/badge/License-MIT-blue.svg?style=flat-square" alt="License"></a>
61
+ <a href="https://github.com/CoreySpohn/hwoutils"><img src="https://img.shields.io/badge/python-3.10%20|%203.11%20|%203.12-blue.svg?style=flat-square&logo=python" alt="Python Versions"></a>
62
+ </p>
63
+
64
+ ---
65
+
66
+ # hwoutils
67
+
68
+ **hwoutils** is the shared utility foundation for the HWO direct imaging simulation suite. It provides JAX-native physical constants, unit conversions, and flux-conserving image transforms used across:
69
+
70
+ - **[yippy](https://github.com/CoreySpohn/yippy)** — Coronagraph performance modeling
71
+ - **[orbix](https://github.com/CoreySpohn/orbix)** — Orbital dynamics and target scheduling
72
+ - **[coronagraphoto](https://github.com/CoreySpohn/coronagraphoto)** — Image simulation for coronagraphic observations
73
+ - **[coronalyze](https://github.com/CoreySpohn/coronalyze)** — Post-processing and SNR analysis
74
+ - **[hwosim](https://github.com/CoreySpohn/hwosim)** — End-to-end mission simulations
75
+
76
+ ## Key Features
77
+
78
+ - **Physical Constants** — Single source of truth for SI constants, conversion factors, and astronomical quantities
79
+ - **Unit Conversions** — Pure JAX conversion functions (angular, flux, distance, time) with zero astropy overhead
80
+ - **Image Transforms** — Flux-conserving resampling, sub-pixel shifts, and cubic spline interpolation
81
+ - **JAX-Native** — All operations are JIT-compilable, differentiable, and GPU-accelerated
82
+
83
+ ## Installation
84
+
85
+ With [uv](https://docs.astral.sh/uv/) (recommended):
86
+
87
+ ```bash
88
+ uv pip install hwoutils
89
+ ```
90
+
91
+ Or with pip:
92
+
93
+ ```bash
94
+ pip install hwoutils
95
+ ```
96
+
97
+ ## Quick Start
98
+
99
+ ```python
100
+ from hwoutils import constants as const
101
+ from hwoutils import conversions as conv
102
+ from hwoutils.transforms import resample_flux
103
+
104
+ # Convert 5 arcsec to lambda/D for a 6m telescope at 550nm
105
+ sep_lod = conv.arcsec_to_lambda_d(5.0, 550.0, 6.0)
106
+
107
+ # Flux-conserving PSF resampling
108
+ resampled = resample_flux(psf, pixscale_src=0.01, pixscale_tgt=0.1, shape_tgt=(64, 64))
109
+ ```
110
+
111
+ ## Documentation
112
+
113
+ Full documentation is available at [hwoutils.readthedocs.io](https://hwoutils.readthedocs.io).
@@ -0,0 +1,58 @@
1
+ <p align="center">
2
+ <a href="https://pypi.org/project/hwoutils/"><img src="https://img.shields.io/pypi/v/hwoutils.svg?style=flat-square&logo=pypi" alt="PyPI"/></a>
3
+ <a href="https://hwoutils.readthedocs.io"><img src="https://readthedocs.org/projects/hwoutils/badge/?version=latest&style=flat-square" alt="Documentation Status"/></a>
4
+ <a href="https://github.com/CoreySpohn/hwoutils/actions/workflows/tests.yml"><img src="https://img.shields.io/github/actions/workflow/status/CoreySpohn/hwoutils/tests.yml?branch=main&style=flat-square&logo=github&label=tests" alt="Tests"/></a>
5
+ <a href="https://github.com/CoreySpohn/hwoutils/blob/main/LICENSE"><img src="https://img.shields.io/badge/License-MIT-blue.svg?style=flat-square" alt="License"></a>
6
+ <a href="https://github.com/CoreySpohn/hwoutils"><img src="https://img.shields.io/badge/python-3.10%20|%203.11%20|%203.12-blue.svg?style=flat-square&logo=python" alt="Python Versions"></a>
7
+ </p>
8
+
9
+ ---
10
+
11
+ # hwoutils
12
+
13
+ **hwoutils** is the shared utility foundation for the HWO direct imaging simulation suite. It provides JAX-native physical constants, unit conversions, and flux-conserving image transforms used across:
14
+
15
+ - **[yippy](https://github.com/CoreySpohn/yippy)** — Coronagraph performance modeling
16
+ - **[orbix](https://github.com/CoreySpohn/orbix)** — Orbital dynamics and target scheduling
17
+ - **[coronagraphoto](https://github.com/CoreySpohn/coronagraphoto)** — Image simulation for coronagraphic observations
18
+ - **[coronalyze](https://github.com/CoreySpohn/coronalyze)** — Post-processing and SNR analysis
19
+ - **[hwosim](https://github.com/CoreySpohn/hwosim)** — End-to-end mission simulations
20
+
21
+ ## Key Features
22
+
23
+ - **Physical Constants** — Single source of truth for SI constants, conversion factors, and astronomical quantities
24
+ - **Unit Conversions** — Pure JAX conversion functions (angular, flux, distance, time) with zero astropy overhead
25
+ - **Image Transforms** — Flux-conserving resampling, sub-pixel shifts, and cubic spline interpolation
26
+ - **JAX-Native** — All operations are JIT-compilable, differentiable, and GPU-accelerated
27
+
28
+ ## Installation
29
+
30
+ With [uv](https://docs.astral.sh/uv/) (recommended):
31
+
32
+ ```bash
33
+ uv pip install hwoutils
34
+ ```
35
+
36
+ Or with pip:
37
+
38
+ ```bash
39
+ pip install hwoutils
40
+ ```
41
+
42
+ ## Quick Start
43
+
44
+ ```python
45
+ from hwoutils import constants as const
46
+ from hwoutils import conversions as conv
47
+ from hwoutils.transforms import resample_flux
48
+
49
+ # Convert 5 arcsec to lambda/D for a 6m telescope at 550nm
50
+ sep_lod = conv.arcsec_to_lambda_d(5.0, 550.0, 6.0)
51
+
52
+ # Flux-conserving PSF resampling
53
+ resampled = resample_flux(psf, pixscale_src=0.01, pixscale_tgt=0.1, shape_tgt=(64, 64))
54
+ ```
55
+
56
+ ## Documentation
57
+
58
+ Full documentation is available at [hwoutils.readthedocs.io](https://hwoutils.readthedocs.io).
@@ -0,0 +1,62 @@
1
+ [build-system]
2
+ requires = ['hatchling', "hatch-fancy-pypi-readme", "hatch-vcs"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "hwoutils"
7
+ description = "Shared JAX-based utilities for the HWO direct imaging simulation suite — constants, conversions, and image transforms."
8
+ authors = [{ name = "Corey Spohn", email = "corey.a.spohn@nasa.gov" }]
9
+ dynamic = ['readme', 'version']
10
+ requires-python = ">=3.10"
11
+ license = { file = "LICENSE" }
12
+ classifiers = [
13
+ "Development Status :: 3 - Alpha",
14
+ "Intended Audience :: Science/Research",
15
+ "License :: OSI Approved :: MIT License",
16
+ "Programming Language :: Python :: 3",
17
+ "Topic :: Scientific/Engineering :: Astronomy",
18
+ ]
19
+ dependencies = [
20
+ "jax",
21
+ "jaxlib",
22
+ ]
23
+ [project.optional-dependencies]
24
+ dev = ["ruff", "pre-commit"]
25
+ docs = [
26
+ "sphinx",
27
+ "myst-nb",
28
+ "sphinx-book-theme",
29
+ "sphinx-autoapi",
30
+ "sphinx_autodoc_typehints",
31
+ "ipython",
32
+ "matplotlib",
33
+ ]
34
+ test = ["nox", "pytest", "hypothesis", "pytest-cov"]
35
+
36
+ [project.urls]
37
+ Homepage = "https://github.com/CoreySpohn/hwoutils"
38
+ Issues = "https://github.com/CoreySpohn/hwoutils/issues"
39
+
40
+ [tool.hatch.version]
41
+ source = "vcs"
42
+
43
+ [tool.hatch.build.hooks.vcs]
44
+ version-file = "src/hwoutils/_version.py"
45
+
46
+ [tool.hatch.metadata.hooks.fancy-pypi-readme]
47
+ content-type = "text/markdown"
48
+
49
+ [[tool.hatch.metadata.hooks.fancy-pypi-readme.fragments]]
50
+ path = "README.md"
51
+
52
+ [tool.ruff.lint]
53
+ select = ["D", "I"]
54
+
55
+ [tool.ruff.lint.pydocstyle]
56
+ convention = "google"
57
+
58
+ [tool.hatch.build.targets.wheel]
59
+ packages = ["src/hwoutils"]
60
+
61
+ [tool.hatch.build.targets.sdist]
62
+ exclude = ["/scripts", "/docs", "/tests", "/.github"]
@@ -0,0 +1,6 @@
1
+ """hwoutils — Shared JAX-based utilities for the HWO simulation suite."""
2
+
3
+ try:
4
+ from hwoutils._version import version as __version__
5
+ except ModuleNotFoundError:
6
+ __version__ = "unknown"
@@ -0,0 +1,34 @@
1
+ # file generated by setuptools-scm
2
+ # don't change, don't track in version control
3
+
4
+ __all__ = [
5
+ "__version__",
6
+ "__version_tuple__",
7
+ "version",
8
+ "version_tuple",
9
+ "__commit_id__",
10
+ "commit_id",
11
+ ]
12
+
13
+ TYPE_CHECKING = False
14
+ if TYPE_CHECKING:
15
+ from typing import Tuple
16
+ from typing import Union
17
+
18
+ VERSION_TUPLE = Tuple[Union[int, str], ...]
19
+ COMMIT_ID = Union[str, None]
20
+ else:
21
+ VERSION_TUPLE = object
22
+ COMMIT_ID = object
23
+
24
+ version: str
25
+ __version__: str
26
+ __version_tuple__: VERSION_TUPLE
27
+ version_tuple: VERSION_TUPLE
28
+ commit_id: COMMIT_ID
29
+ __commit_id__: COMMIT_ID
30
+
31
+ __version__ = version = '1.0.0'
32
+ __version_tuple__ = version_tuple = (1, 0, 0)
33
+
34
+ __commit_id__ = commit_id = None
@@ -0,0 +1,81 @@
1
+ """Physical constants and unit conversion factors.
2
+
3
+ Single source of truth for all constants used across the HWO simulation suite
4
+ (yippy, orbix, coronagraphoto, coronalyze, hwosim). All values are plain floats
5
+ or JAX arrays — no astropy.units dependency.
6
+ """
7
+
8
+ import jax.numpy as jnp
9
+
10
+ # ---------------------------------------------------------------------------
11
+ # Mathematical constants
12
+ # ---------------------------------------------------------------------------
13
two_pi = 2 * jnp.pi
pi_over_2 = jnp.pi / 2
# Machine epsilon of float32, as reported by jnp.finfo.
eps = jnp.finfo(jnp.float32).eps

# ---------------------------------------------------------------------------
# Fundamental physical constants (SI)
# ---------------------------------------------------------------------------
h = 6.62607015e-34  # Planck constant [J·s]
c = 299792458.0  # Speed of light [m/s]
k_B = 1.380649e-23  # Boltzmann constant [J/K]
sigma_SB = 5.670374419e-8  # Stefan-Boltzmann constant [W·m⁻²·K⁻⁴]

# ---------------------------------------------------------------------------
# Gravitational constant
# ---------------------------------------------------------------------------
G_si = 6.67430e-11  # [m³·kg⁻¹·s⁻²]
G = 1.488185170234519e-34  # [AU³·kg⁻¹·d⁻²] — for orbital dynamics

# ---------------------------------------------------------------------------
# Flux conversion
# ---------------------------------------------------------------------------
Jy = 1e-26  # 1 Jansky [W·m⁻²·Hz⁻¹]

# ---------------------------------------------------------------------------
# Length conversions
# ---------------------------------------------------------------------------
nm2m = 1e-9
m2nm = 1e9
um2m = 1e-6
m2um = 1e6
nm2um = 1e-3
um2nm = 1e3

# ---------------------------------------------------------------------------
# Distance conversions
# ---------------------------------------------------------------------------
AU2m = 1.495978707e11
m2AU = 6.684587122268445e-12
# Full-precision parsec so that pc2m == pc2AU * AU2m; the previous rounded
# value (3.0857e16) disagreed with pc2AU at the 7e-6 relative level.
pc2m = 3.0856775814913673e16
m2pc = 1.0 / pc2m
pc2AU = 2.062648062470964e05
# Earth radius of 6.3781e6 m, chosen so that Rearth2m == Rearth2AU * AU2m;
# the previous value (6.371e6) was ~0.11% smaller than what Rearth2AU implies.
Rearth2m = 6.3781e6
Rearth2AU = 4.263496512454037e-05

# ---------------------------------------------------------------------------
# Mass conversions
# ---------------------------------------------------------------------------
Msun2kg = 1.988409870698051e30
Mearth2kg = 5.972167867791379e24
Mjup2kg = 1.898124571735094e27

# ---------------------------------------------------------------------------
# Angular conversions
# ---------------------------------------------------------------------------
rad2arcsec = 206264.80624709636
arcsec2rad = 4.84813681109536e-06
mas2arcsec = 1e-3
arcsec2mas = 1e3
deg2rad = jnp.pi / 180.0
rad2deg = 180.0 / jnp.pi

# ---------------------------------------------------------------------------
# Time conversions
# ---------------------------------------------------------------------------
yr2s = 365.25 * 86400.0  # Julian year of 365.25 days
s2yr = 1.0 / yr2s
d2s = 86400.0
s2d = 1.157407407407407e-05
J2000_JD = 2451545.0
@@ -0,0 +1,302 @@
1
+ """Unit conversion functions using centralized constants.
2
+
3
+ Pure JAX implementations — no astropy dependency. Functions are intentionally
4
+ NOT JIT-compiled so JAX can fuse them into larger computation graphs.
5
+ """
6
+
7
+ import jax.numpy as jnp
8
+
9
+ from hwoutils import constants as const
10
+
11
+ # ---------------------------------------------------------------------------
12
+ # Flux conversions
13
+ # ---------------------------------------------------------------------------
14
+
15
+
16
def jy_to_photons_per_nm_per_m2(flux_jy, wavelength_nm):
    """Turn a flux density in Janskys into a photon flux density.

    The input is scaled by ``const.Jy`` and then divided by the product of
    the wavelength (in nm) and the Planck constant ``const.h``.

    Args:
        flux_jy: Flux density in Janskys.
        wavelength_nm: Wavelength in nanometers.

    Returns:
        Flux density in photons/s/nm/m².
    """
    flux_si = flux_jy * const.Jy
    per_photon_scale = wavelength_nm * const.h
    return flux_si / per_photon_scale
27
+
28
+
29
def photons_per_nm_per_m2_to_jy(flux_phot, wavelength_nm):
    """Turn a photon flux density into a flux density in Janskys.

    Inverse of ``jy_to_photons_per_nm_per_m2``: the input is multiplied by
    the product of the wavelength (in nm) and ``const.h``, then divided by
    ``const.Jy``.

    Args:
        flux_phot: Flux density in photons/s/nm/m².
        wavelength_nm: Wavelength in nanometers.

    Returns:
        Flux density in Janskys.
    """
    per_photon_scale = wavelength_nm * const.h
    flux_si = flux_phot * per_photon_scale
    return flux_si / const.Jy
40
+
41
+
42
def mag_per_arcsec2_to_jy_per_arcsec2(mag_per_arcsec2):
    """Translate an AB surface brightness from mag/arcsec² into Jy/arcsec².

    Args:
        mag_per_arcsec2: Surface brightness in magnitudes per arcsec².

    Returns:
        Surface brightness in Jy/arcsec².
    """
    # AB magnitude zero point in Janskys.
    ab_zero_point_jy = 3631.0
    # Each magnitude corresponds to a factor of 10**-0.4 in flux.
    flux_ratio = 10 ** (-0.4 * mag_per_arcsec2)
    return ab_zero_point_jy * flux_ratio
53
+
54
+
55
+ # ---------------------------------------------------------------------------
56
+ # Length conversions
57
+ # ---------------------------------------------------------------------------
58
+
59
+
60
def nm_to_um(length_nm):
    """Express a length given in nanometers in micrometers.

    Args:
        length_nm: Length in nanometers.

    Returns:
        The length scaled by ``const.nm2um``.
    """
    length_um = length_nm * const.nm2um
    return length_um
63
+
64
+
65
def um_to_nm(length_um):
    """Express a length given in micrometers in nanometers.

    Args:
        length_um: Length in micrometers.

    Returns:
        The length scaled by ``const.um2nm``.
    """
    length_nm = length_um * const.um2nm
    return length_nm
68
+
69
+
70
def au_to_m(length_au):
    """Express a length given in AU in meters.

    Args:
        length_au: Length in astronomical units.

    Returns:
        The length scaled by ``const.AU2m``.
    """
    length_m = length_au * const.AU2m
    return length_m
73
+
74
+
75
def m_to_au(length_m):
    """Express a length given in meters in AU.

    Args:
        length_m: Length in meters.

    Returns:
        The length scaled by ``const.m2AU``.
    """
    length_au = length_m * const.m2AU
    return length_au
78
+
79
+
80
def Rearth_to_m(length_Rearth):
    """Express a length given in Earth radii in meters.

    Args:
        length_Rearth: Length in Earth radii.

    Returns:
        The length scaled by ``const.Rearth2m``.
    """
    length_m = length_Rearth * const.Rearth2m
    return length_m
83
+
84
+
85
+ # ---------------------------------------------------------------------------
86
+ # Velocity conversions
87
+ # ---------------------------------------------------------------------------
88
+
89
+
90
def au_per_yr_to_m_per_s(velocity_au_per_yr):
    """Express a velocity given in AU/yr in m/s.

    Args:
        velocity_au_per_yr: Velocity in AU per year.

    Returns:
        The velocity multiplied by ``const.AU2m`` and divided by
        ``const.yr2s``.
    """
    meters_per_year = velocity_au_per_yr * const.AU2m
    return meters_per_year / const.yr2s
93
+
94
+
95
+ # ---------------------------------------------------------------------------
96
+ # Angular conversions
97
+ # ---------------------------------------------------------------------------
98
+
99
+
100
def arcsec_to_rad(angle_arcsec):
    """Express an angle given in arcseconds in radians.

    Args:
        angle_arcsec: Angle in arcseconds.

    Returns:
        The angle scaled by ``const.arcsec2rad``.
    """
    angle_rad = angle_arcsec * const.arcsec2rad
    return angle_rad
103
+
104
+
105
def rad_to_arcsec(angle_rad):
    """Express an angle given in radians in arcseconds.

    Args:
        angle_rad: Angle in radians.

    Returns:
        The angle scaled by ``const.rad2arcsec``.
    """
    angle_arcsec = angle_rad * const.rad2arcsec
    return angle_arcsec
108
+
109
+
110
def mas_to_arcsec(angle_mas):
    """Express an angle given in milliarcseconds in arcseconds.

    Args:
        angle_mas: Angle in milliarcseconds.

    Returns:
        The angle scaled by ``const.mas2arcsec``.
    """
    angle_arcsec = angle_mas * const.mas2arcsec
    return angle_arcsec
113
+
114
+
115
def arcsec_to_mas(angle_arcsec):
    """Express an angle given in arcseconds in milliarcseconds.

    Args:
        angle_arcsec: Angle in arcseconds.

    Returns:
        The angle scaled by ``const.arcsec2mas``.
    """
    angle_mas = angle_arcsec * const.arcsec2mas
    return angle_mas
118
+
119
+
120
def arcsec_to_lambda_d(angle_arcsec, wavelength_nm, diameter_m):
    """Express an angular separation in units of lambda/D.

    Args:
        angle_arcsec: Angular separation in arcseconds.
        wavelength_nm: Wavelength in nanometers.
        diameter_m: Telescope diameter in meters.

    Returns:
        Angular separation in lambda/D.
    """
    # Size of one lambda/D element in radians.
    resolution_element_rad = (wavelength_nm * const.nm2m) / diameter_m
    separation_rad = angle_arcsec * const.arcsec2rad
    return separation_rad / resolution_element_rad
135
+
136
+
137
def lambda_d_to_arcsec(angle_lambda_d, wavelength_nm, diameter_m):
    """Express a separation given in lambda/D units in arcseconds.

    Args:
        angle_lambda_d: Angular separation in lambda/D.
        wavelength_nm: Wavelength in nanometers.
        diameter_m: Telescope diameter in meters.

    Returns:
        Angular separation in arcseconds.
    """
    # Size of one lambda/D element in radians.
    resolution_element_rad = (wavelength_nm * const.nm2m) / diameter_m
    separation_rad = angle_lambda_d * resolution_element_rad
    return separation_rad * const.rad2arcsec
152
+
153
+
154
+ # ---------------------------------------------------------------------------
155
+ # Mass conversions
156
+ # ---------------------------------------------------------------------------
157
+
158
+
159
def Msun_to_kg(mass_solar):
    """Express a mass given in solar masses in kilograms.

    Args:
        mass_solar: Mass in solar masses.

    Returns:
        The mass scaled by ``const.Msun2kg``.
    """
    mass_kg = mass_solar * const.Msun2kg
    return mass_kg
162
+
163
+
164
def Mearth_to_kg(mass_earth):
    """Express a mass given in Earth masses in kilograms.

    Args:
        mass_earth: Mass in Earth masses.

    Returns:
        The mass scaled by ``const.Mearth2kg``.
    """
    mass_kg = mass_earth * const.Mearth2kg
    return mass_kg
167
+
168
+
169
+ # ---------------------------------------------------------------------------
170
+ # Distance conversions
171
+ # ---------------------------------------------------------------------------
172
+
173
+
174
def au_to_arcsec(distance_au, distance_pc):
    """Turn a physical separation in AU into an angular separation.

    The separation in AU is divided by the system distance in parsecs.

    Args:
        distance_au: Physical distance in AU.
        distance_pc: Distance to system in parsecs.

    Returns:
        Angular separation in arcseconds.
    """
    angle_arcsec = distance_au / distance_pc
    return angle_arcsec
185
+
186
+
187
def arcsec_to_au(angle_arcsec, distance_pc):
    """Turn an angular separation into a physical separation in AU.

    The angle in arcseconds is multiplied by the system distance in parsecs.

    Args:
        angle_arcsec: Angular separation in arcseconds.
        distance_pc: Distance to system in parsecs.

    Returns:
        Physical distance in AU.
    """
    distance_au = angle_arcsec * distance_pc
    return distance_au
198
+
199
+
200
+ # ---------------------------------------------------------------------------
201
+ # Time conversions
202
+ # ---------------------------------------------------------------------------
203
+
204
+
205
def years_to_days(time_years):
    """Express a duration given in years in days.

    Args:
        time_years: Duration in years.

    Returns:
        The duration multiplied by 365.25 days per year.
    """
    time_days = time_years * 365.25
    return time_days
208
+
209
+
210
def days_to_years(time_days):
    """Express a duration given in days in years.

    Args:
        time_days: Duration in days.

    Returns:
        The duration divided by 365.25 days per year.
    """
    time_years = time_days / 365.25
    return time_years
213
+
214
+
215
def is_leap_year(year):
    """Report whether a year is a leap year.

    A year qualifies when it is divisible by 4 and is either not divisible
    by 100 or is divisible by 400. Bitwise ``&``/``|`` are used rather than
    ``and``/``or`` to combine the three tests.

    Args:
        year: The year to check.

    Returns:
        True if the year is a leap year.
    """
    divisible_by_4 = year % 4 == 0
    not_a_century = year % 100 != 0
    divisible_by_400 = year % 400 == 0
    return divisible_by_4 & (not_a_century | divisible_by_400)
225
+
226
+
227
def days_in_year(year):
    """Give the length of a year in days (365 or 366).

    Args:
        year: The year to check.

    Returns:
        Number of days.
    """
    # The leap-year flag contributes the extra day when it is true.
    extra_day = is_leap_year(year)
    return 365 + extra_day
237
+
238
+
239
def gregorian_to_jd(year, month, day):
    """Compute the Julian day for a Gregorian calendar date.

    Args:
        year: The year.
        month: The month.
        day: The day.

    Returns:
        The Julian day.
    """
    # 1 for January/February, 0 otherwise; moves those months to the end of
    # the previous (shifted) year so the year effectively starts in March.
    jan_feb = jnp.floor((14 - month) / 12)
    shifted_year = year + 4800 - jan_feb
    shifted_month = month + 12 * jan_feb - 3

    # Accumulate the Julian day number term by term.
    jdn = day + jnp.floor((153 * shifted_month + 2) / 5)
    jdn = jdn + 365 * shifted_year
    jdn = jdn + jnp.floor(shifted_year / 4)
    jdn = jdn - jnp.floor(shifted_year / 100)
    jdn = jdn + jnp.floor(shifted_year / 400)
    jdn = jdn - 32045

    # Subtract half a day from the day number.
    return jdn - 0.5
263
+
264
+
265
def jd_to_decimal_year(jd):
    """Convert a Julian day to a decimal year.

    The calendar year is first estimated from the mean Gregorian year length
    (365.2425 days) counted from 1970-01-01 (JD 2440587.5). Because real years
    are 365 or 366 days long, that estimate can land one year too high or one
    year too low near a year boundary, so it is corrected in both directions
    before the fractional part is computed.

    Args:
        jd: The Julian day.

    Returns:
        The decimal year.
    """
    year_approx = 1970.0 + (jd - 2440587.5) / 365.2425
    year = jnp.floor(year_approx)

    jd_start = gregorian_to_jd(year, 1, 1)
    jd_end = gregorian_to_jd(year + 1, 1, 1)

    # Step back when the estimate overshot, and forward when it undershot
    # (jd already at or past the next January 1st).
    year = jnp.where(
        jd < jd_start,
        year - 1,
        jnp.where(jd >= jd_end, year + 1, year),
    )
    jd_start = gregorian_to_jd(year, 1, 1)
    jd_end = gregorian_to_jd(year + 1, 1, 1)

    # Fraction of the (corrected) year elapsed at jd.
    return year + (jd - jd_start) / (jd_end - jd_start)
285
+
286
+
287
def decimal_year_to_jd(decimal_year):
    """Compute the Julian day corresponding to a decimal year.

    The integer part selects the calendar year; the fractional part is
    applied linearly between January 1st of that year and January 1st of
    the next.

    Args:
        decimal_year: The decimal year.

    Returns:
        The Julian day.
    """
    whole_year = jnp.floor(decimal_year)
    fraction = decimal_year - whole_year

    year_begin_jd = gregorian_to_jd(whole_year, 1, 1)
    next_year_begin_jd = gregorian_to_jd(whole_year + 1, 1, 1)
    year_length = next_year_begin_jd - year_begin_jd

    return year_begin_jd + fraction * year_length
@@ -0,0 +1,243 @@
1
+ """Cubic spline interpolation for JAX.
2
+
3
+ This module contains an implementation of ``map_coordinates`` with cubic spline
4
+ interpolation. It is adapted from the JAX project (PR #14218 by Louis Desdoigts)
5
+ and is licensed under the Apache 2.0 license.
6
+
7
+ Original JAX source:
8
+ https://github.com/google/jax/blob/main/jax/_src/scipy/ndimage.py
9
+ """
10
+
11
+ from collections.abc import Callable, Sequence
12
+ from typing import Dict
13
+
14
+ import jax
15
+ import jax.numpy as jnp
16
+ from jax import lax, vmap
17
+ from jax._src import api, util
18
+ from jax._src.numpy.linalg import inv
19
+ from jax._src.typing import Array, ArrayLike
20
+ from jax._src.util import safe_zip as zip
21
+
22
+
23
+ def _nonempty_prod(arrs: Sequence[Array]) -> Array:
24
+ return arrs[0] if len(arrs) == 1 else lax.reduce(arrs, lax.mul)
25
+
26
+
27
+ def _nonempty_sum(arrs: Sequence[Array]) -> Array:
28
+ return arrs[0] if len(arrs) == 1 else lax.reduce(arrs, lax.add)
29
+
30
+
31
+ def _mirror_index_fixer(index: Array, size: int) -> Array:
32
+ s = size - 1
33
+ return jnp.abs((index + s) % (2 * s) - s)
34
+
35
+
36
def _reflect_index_fixer(index: Array, size: int) -> Array:
    """Fold indices into the axis by reflecting about its outer edges.

    Implemented by applying the mirror fold on a doubled grid
    (``2 * index + 1`` over ``2 * size + 1`` points) and halving the result.

    Args:
        index: Integer indices, possibly outside the axis.
        size: Length of the axis.

    Returns:
        The folded indices.
    """
    doubled_index = 2 * index + 1
    doubled_size = 2 * size + 1
    mirrored = _mirror_index_fixer(doubled_index, doubled_size)
    return jnp.floor_divide(mirrored - 1, 2)
38
+
39
+
40
# Boundary mode name -> function ``(index, axis_length)`` that turns a raw
# index into the index actually read from the array.
_INDEX_FIXERS: Dict[str, Callable[[Array, int], Array]] = {
    # Left untouched; out-of-range reads are masked by the caller.
    "constant": lambda idx, length: idx,
    # Clamp to the first/last sample.
    "nearest": lambda idx, length: jnp.clip(idx, 0, length - 1),
    # Periodic repetition of the axis.
    "wrap": lambda idx, length: idx % length,
    "mirror": _mirror_index_fixer,
    "reflect": _reflect_index_fixer,
}
47
+
48
+
49
+ def _round_half_away_from_zero(a: Array) -> Array:
50
+ return a if jnp.issubdtype(a.dtype, jnp.integer) else lax.round(a)
51
+
52
+
53
def _nearest_indices_and_weights(coordinate: Array) -> list[tuple[Array, Array]]:
    """Return the single interpolation node used for order-0 lookup.

    Args:
        coordinate: Coordinates along one axis.

    Returns:
        A one-element list holding ``(index, weight)``: the rounded
        coordinate as int32 and a unit weight in the coordinate's dtype.
    """
    rounded = _round_half_away_from_zero(coordinate)
    nearest_index = rounded.astype(jnp.int32)
    unit_weight = coordinate.dtype.type(1)
    return [(nearest_index, unit_weight)]
57
+
58
+
59
+ def _linear_indices_and_weights(coordinate: Array) -> list[tuple[Array, Array]]:
60
+ lower = jnp.floor(coordinate)
61
+ upper_weight = coordinate - lower
62
+ lower_weight = coordinate.dtype.type(1) - upper_weight
63
+ index = lower.astype(jnp.int32)
64
+ return [(index, lower_weight), (index + 1, upper_weight)]
65
+
66
+
67
def _cubic_indices_and_weights(coordinate: Array) -> list[tuple[Array, Array]]:
    """Return the interpolation nodes used for order-3 lookup.

    Args:
        coordinate: Coordinates along one axis.

    Returns:
        The ``(index, weight)`` pairs produced by ``_spline_point``.
    """
    # No coefficients are supplied; only the coordinate is passed on.
    nodes = _spline_point(None, coordinate)
    return nodes
69
+
70
+
71
+ def _build_matrix(n: int, diag: float = 4) -> Array:
72
+ M = jnp.zeros((n, n))
73
+ idx = jnp.arange(n)
74
+ M = M.at[idx, idx].set(diag)
75
+ M = M.at[idx[:-1], idx[:-1] + 1].set(1)
76
+ M = M.at[idx[:-1] + 1, idx[:-1]].set(1)
77
+ return M
78
+
79
+
80
+ def _construct_vector(data: Array, c2: Array, cnp2: Array) -> Array:
81
+ n = data.shape[0]
82
+ d = jnp.zeros(n)
83
+ d = d.at[0].set(6 * data[0] - c2)
84
+ d = d.at[-1].set(6 * data[-1] - cnp2)
85
+ d = d.at[1:-1].set(6 * data[1:-1])
86
+ return d
87
+
88
+
89
def _solve_coefficients(data: Array, A_inv: Array, h=1) -> Array:
    """Solve for coefficients column by column and pad with zero end rows.

    Args:
        data: 2-D array; the computation is mapped over axis 1 of its
            first differences along axis 0.
        A_inv: Inverse system matrix applied to each right-hand side.
        h: Sample spacing dividing the first differences.

    Returns:
        The solved coefficients with one row of zeros prepended and one
        appended along axis 0.

    NOTE(review): each right-hand side has ``data.shape[0] - 1`` entries,
    while ``_spline_coefficients`` passes the inverse of an
    ``(n - 2) x (n - 2)`` matrix, so the product looks shape-inconsistent
    on that path. Confirm the intended system before relying on this.
    """
    finite_diff = jnp.diff(data, axis=0) / h
    # Both boundary terms are fixed at zero.
    c2 = 0.0
    cnp2 = 0.0
    d = vmap(_construct_vector, in_axes=(1, None, None))(finite_diff, c2, cnp2)
    c = vmap(jnp.dot, in_axes=(None, 0))(A_inv, d).T
    zero_row = jnp.zeros((1, c.shape[1]))
    return jnp.concatenate([zero_row, c, zero_row], axis=0)
100
+
101
+
102
def _spline_coefficients(data: Array) -> Array:
    """Compute spline coefficients for ``data``.

    Args:
        data: Sample array; its leading dimension sets the system size.

    Returns:
        The result of ``_solve_coefficients`` using the inverse of the
        ``(n - 2) x (n - 2)`` tridiagonal matrix, with ``n = data.shape[0]``.
    """
    num_samples = data.shape[0]
    system_inverse = inv(_build_matrix(num_samples - 2))
    return _solve_coefficients(data, system_inverse)
108
+
109
+
110
+ def _spline_basis(t: Array) -> Array:
111
+ abs_t = jnp.abs(t)
112
+ return jnp.where(
113
+ abs_t <= 1,
114
+ 2 / 3 - abs_t**2 + abs_t**3 / 2,
115
+ jnp.where(abs_t <= 2, (2 - abs_t) ** 3 / 6, 0.0),
116
+ )
117
+
118
+
119
def _spline_value(coefficients: Array, coordinate: Array, indexes: Array) -> Array:
    """Evaluate a basis-weighted sum of coefficients at one coordinate.

    Args:
        coefficients: Coefficients, one per entry of ``indexes``.
        coordinate: Position at which to evaluate.
        indexes: Positions of the basis functions.

    Returns:
        Sum over axis 0 of ``coefficients * _spline_basis(coordinate - indexes)``.
    """
    offsets = coordinate - indexes
    basis_weights = _spline_basis(offsets)
    weighted = coefficients * basis_weights
    return jnp.sum(weighted, axis=0)
123
+
124
+
125
def _spline_point(
    coefficients: Array, coordinate: Array
) -> list[tuple[Array, Array]]:
    """Return the four cubic interpolation nodes around ``coordinate``.

    Args:
        coefficients: Unused; kept so existing callers keep working.
        coordinate: Coordinates along one axis.

    Returns:
        Four ``(index, weight)`` pairs for the int32 indices
        ``floor(coordinate) + {-1, 0, 1, 2}``, each weighted by
        ``_spline_basis(coordinate - index)``.
    """
    base = jnp.floor(coordinate).astype(jnp.int32)
    # Build the nodes directly; stacking them into one array only to
    # iterate over its leading axis adds work without changing the values.
    nodes = []
    for offset in (-1, 0, 1, 2):
        node_index = base + offset
        nodes.append((node_index, _spline_basis(coordinate - node_index)))
    return nodes
131
+
132
+
133
def _cubic_spline(input: Array, coordinates: Array) -> Array:
    """Evaluate the spline of ``input`` at each entry of ``coordinates``.

    Args:
        input: Sample array; coefficients come from ``_spline_coefficients``.
        coordinates: Positions to evaluate, mapped over their leading axis.

    Returns:
        ``_spline_value`` evaluated for every coordinate against the sample
        positions ``0 .. input.shape[0] - 1``.
    """
    spline_coeffs = _spline_coefficients(input)
    sample_positions = jnp.arange(input.shape[0])
    evaluate_all = vmap(_spline_value, in_axes=(None, 0, None))
    return evaluate_all(spline_coeffs, coordinates, sample_positions)
139
+
140
+
141
def _map_coordinates(
    input: ArrayLike,
    coordinates: Sequence[ArrayLike],
    order: int,
    mode: str,
    cval: ArrayLike,
) -> Array:
    """Sample ``input`` at fractional ``coordinates`` by separable interpolation.

    Each axis contributes a small set of ``(index, weight)`` nodes. The
    output is the sum, over every combination of one node per axis, of the
    product of the node weights times the array value at the node indices.

    Args:
        input: Array to sample.
        coordinates: One coordinate array per axis of ``input``.
        order: Interpolation order: 0 (nearest), 1 (linear) or 3 (cubic
            basis weights).
        mode: Boundary handling; a key of ``_INDEX_FIXERS``.
        cval: Value substituted for out-of-range samples when ``mode`` is
            ``"constant"``.

    Returns:
        Interpolated values in the dtype of ``input``; for integer inputs
        the result is rounded before the cast.

    Raises:
        ValueError: If the number of coordinate arrays differs from
            ``input.ndim`` or ``input`` is zero-dimensional.
        NotImplementedError: For an unsupported ``mode`` or ``order``.

    NOTE(review): for ``order == 3`` the cubic basis weights are applied to
    the raw samples (no coefficient prefilter), so the output does not
    reproduce the samples exactly at integer coordinates. Confirm this is
    intended.
    """
    # Function-scope import so this block adds no new file-level import.
    import itertools

    input_arr = jnp.asarray(input)
    if jnp.issubdtype(input_arr.dtype, jnp.floating):
        coordinates_arr = [jnp.asarray(c, dtype=input_arr.dtype) for c in coordinates]
    else:
        # Casting coordinates to a non-float input dtype would truncate
        # their sub-pixel part, so keep them as given.
        coordinates_arr = [jnp.asarray(c) for c in coordinates]
    cval = jnp.asarray(cval, input_arr.dtype)

    if len(coordinates_arr) != input_arr.ndim:
        raise ValueError(
            f"coordinates must be a sequence of length input.ndim = {input_arr.ndim}, "
            f"got {len(coordinates_arr)}"
        )
    if input_arr.ndim == 0:
        raise ValueError("input must have at least one dimension")

    index_fixer = _INDEX_FIXERS.get(mode)
    if index_fixer is None:
        raise NotImplementedError(
            f"jax.scipy.ndimage.map_coordinates does not support mode {mode!r}"
        )

    if order == 0:
        interp_fun = _nearest_indices_and_weights
    elif order == 1:
        interp_fun = _linear_indices_and_weights
    elif order == 3:
        interp_fun = _cubic_indices_and_weights
    else:
        raise NotImplementedError(
            f"jax.scipy.ndimage.map_coordinates only supports order 0, 1, or 3, got {order}"
        )

    # Only "constant" mode can read outside the array; every other mode
    # folds indices back into range, so no validity mask is needed there.
    needs_mask = mode == "constant"

    per_axis_nodes = []
    for coordinate, size in zip(coordinates_arr, input_arr.shape):
        axis_nodes = []
        for index, weight in interp_fun(coordinate):
            in_range = (index >= 0) & (index < size) if needs_mask else None
            axis_nodes.append((index_fixer(index, size), weight, in_range))
        per_axis_nodes.append(axis_nodes)

    result = None
    # Separable interpolation needs every combination of one node per axis
    # (e.g. all four corners for 2-D linear), i.e. the Cartesian product.
    for nodes in itertools.product(*per_axis_nodes):
        indices = tuple(node[0] for node in nodes)

        weight = nodes[0][1]
        for node in nodes[1:]:
            weight = weight * node[1]

        contribution = input_arr[indices]
        if needs_mask:
            all_valid = nodes[0][2]
            for node in nodes[1:]:
                all_valid = all_valid & node[2]
            contribution = jnp.where(all_valid, contribution, cval)

        term = weight * contribution
        result = term if result is None else result + term

    if jnp.issubdtype(input_arr.dtype, jnp.integer):
        result = _round_half_away_from_zero(result)
    return result.astype(input_arr.dtype)
219
+
220
+
221
def map_coordinates(
    input: ArrayLike,
    coordinates: Sequence[ArrayLike],
    order: int,
    mode: str = "constant",
    cval: ArrayLike = 0.0,
) -> Array:
    """Map an array onto new coordinates by interpolation.

    Thin JIT-compiled wrapper around ``_map_coordinates``; ``order`` and
    ``mode`` are compiled as static arguments.

    Args:
        input: The input array.
        coordinates: Sequence of coordinate arrays for each dimension.
        order: Interpolation order (0=nearest, 1=linear, 3=cubic spline).
        mode: Boundary handling ('constant', 'nearest', 'wrap', 'mirror',
            'reflect').
        cval: Value for 'constant' mode outside boundaries.

    Returns:
        Interpolated values at the given coordinates.
    """
    # Use the public jax.jit entry point; jax._src is private API and may
    # change between releases.
    jitted = jax.jit(_map_coordinates, static_argnums=(2, 3))
    return jitted(input, coordinates, order, mode, cval)
@@ -0,0 +1,119 @@
1
+ """Image transformation utilities.
2
+
3
+ Flux-conserving resampling and sub-pixel image operations. All functions
4
+ are JIT-compilable and differentiable.
5
+ """
6
+
7
+ import functools
8
+
9
+ import jax
10
+ import jax.numpy as jnp
11
+
12
+ from hwoutils.map_coordinates import map_coordinates
13
+
14
+
15
def ccw_rotation_matrix(rotation_deg: float) -> jax.Array:
    """Build the 2x2 counter-clockwise rotation matrix for an angle.

    Args:
        rotation_deg: Rotation angle in degrees. Positive = counter-clockwise.

    Returns:
        ``[[cos, -sin], [sin, cos]]`` of the angle as a JAX array.
    """
    angle_rad = jnp.deg2rad(rotation_deg)
    c, s = jnp.cos(angle_rad), jnp.sin(angle_rad)
    top_row = [c, -s]
    bottom_row = [s, c]
    return jnp.array([top_row, bottom_row])
33
+
34
+
35
@functools.partial(jax.jit, static_argnames=["order", "mode"])
def shift_image(
    image: jax.Array,
    shift_y: float,
    shift_x: float,
    order: int = 3,
    mode: str = "constant",
    cval: float = 0.0,
) -> jax.Array:
    """Translate a 2D image by a sub-pixel offset.

    Works by inverse mapping: output pixel (y, x) is sampled from the
    input at (y - shift_y, x - shift_x), so the content moves by
    (+shift_y, +shift_x).

    Args:
        image: 2D input image.
        shift_y: Shift along axis 0 in pixels (positive moves content to
            higher row indices).
        shift_x: Shift along axis 1 in pixels (positive moves content to
            higher column indices).
        order: Interpolation order passed to ``map_coordinates``.
        mode: Boundary handling mode passed to ``map_coordinates``.
        cval: Value for 'constant' mode outside boundaries.

    Returns:
        Shifted image with the same shape as the input.
    """
    n_rows, n_cols = image.shape
    rows, cols = jnp.mgrid[:n_rows, :n_cols]
    source_rows = rows - shift_y
    source_cols = cols - shift_x
    return map_coordinates(
        image, [source_rows, source_cols], order=order, mode=mode, cval=cval
    )
64
+
65
+
66
@functools.partial(jax.jit, static_argnames=["shape_tgt"])
def resample_flux(
    f_src: jax.Array,
    pixscale_src: float,
    pixscale_tgt: float,
    shape_tgt: tuple[int, int],
    rotation_deg: float = 0.0,
) -> jax.Array:
    """Resample a per-pixel flux image onto a rotated and rescaled grid.

    The source is divided by its pixel area to obtain surface brightness,
    sampled with order-3 interpolation at the source position of every
    target pixel centre, and multiplied by the target pixel area to get
    flux per target pixel again. Samples falling outside the source
    contribute 0.

    Args:
        f_src: Source image (2D) with integrated flux per pixel.
        pixscale_src: Pixel scale of the source image.
        pixscale_tgt: Pixel scale of the target image (same units as src).
        shape_tgt: Target shape (ny_tgt, nx_tgt).
        rotation_deg: Rotation angle in degrees passed to
            ``ccw_rotation_matrix``.

    Returns:
        Resampled image of shape (ny_tgt, nx_tgt).
    """
    src_rows, src_cols = f_src.shape
    tgt_rows, tgt_cols = shape_tgt

    # Flux per unit area of the source.
    brightness_src = f_src / (pixscale_src**2)

    # Linear map taking TARGET pixel coordinates to SOURCE coordinates.
    zoom = pixscale_tgt / pixscale_src
    target_to_source = ccw_rotation_matrix(rotation_deg) * zoom

    # The offset makes the two grid centres coincide.
    centre_src = jnp.array([(src_rows - 1) / 2.0, (src_cols - 1) / 2.0])
    centre_tgt = jnp.array([(tgt_rows - 1) / 2.0, (tgt_cols - 1) / 2.0])
    translation = centre_src - target_to_source @ centre_tgt

    # Target pixel centres stacked as (2, ny_tgt, nx_tgt).
    row_grid, col_grid = jnp.meshgrid(
        jnp.arange(tgt_rows), jnp.arange(tgt_cols), indexing="ij"
    )
    target_grid = jnp.stack([row_grid, col_grid], axis=0)

    flat_source = target_to_source @ target_grid.reshape(2, -1) + translation[:, None]
    source_grid = flat_source.reshape(target_grid.shape)

    brightness_tgt = map_coordinates(
        brightness_src,
        [source_grid[0], source_grid[1]],
        order=3,
        mode="constant",
        cval=0.0,
    )

    # Surface brightness times target pixel area gives flux per pixel.
    return brightness_tgt * (pixscale_tgt**2)