spectracles 0.5.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Thomas Hilder
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,73 @@
1
+ Metadata-Version: 2.4
2
+ Name: spectracles
3
+ Version: 0.5.0
4
+ Summary: Unified spectrospatial models: glasses for your spectra.
5
+ Author-email: Thomas Hilder <Thomas.Hilder@monash.edu>, "Andrew R. Casey" <Andrew.Casey@monash.edu>
6
+ License-File: LICENSE
7
+ Requires-Python: >=3.13
8
+ Requires-Dist: equinox>=0.13.0
9
+ Requires-Dist: jax-finufft
10
+ Requires-Dist: jax[cpu]>=0.7.0
11
+ Requires-Dist: matplotlib>=3.10.5
12
+ Requires-Dist: networkx>=3.5
13
+ Requires-Dist: tqdm>=4.67.1
14
+ Description-Content-Type: text/markdown
15
+
16
+ <div id="top"></div>
17
+
18
+ <!-- PROJECT SHIELDS -->
19
+ <!-- [![PyPI Package][pypi-shield]][pypi-url] -->
20
+ <!-- [![JOSS][JOSS-shield]][JOSS-url] -->
21
+
22
+ <!-- ![GitHub Workflow Status](https://img.shields.io/github/actions/workflow/status/TomHilder/wakeflow/Tests.yml?label=tests&style=flat-square)
23
+ [![Docs][docs-status-shield]][docs-status-url] -->
24
+
25
+ <!-- PROJECT LOGO -->
26
+ <br />
27
+ <div align="center">
28
+ <a href="https://github.com/TomHilder/spectracles">
29
+ <img src="https://raw.githubusercontent.com/TomHilder/spectracles/main/logo.png" alt="spectracles" width="420">
30
+ </a>
31
+
32
+ <!-- <h3 align="center">Wakeflow</h3> -->
33
+
34
+ <p align="center">
35
+ Unified spectrospatial models for integral field spectroscopy in jax
36
+ </p>
37
+ </div>
38
+
39
+ <!-- <div align="center">
40
+ <img src="https://raw.githubusercontent.com/TomHilder/spectracles/main/logo.png" alt="spectracles" width="420"></img>
41
+ </div> -->
42
+
43
+ ## Glasses for your spectra
44
+
45
+ Spectracles is a Python library for inferring properties of IFU/IFS spectra as continuous functions of sky position.
46
+
47
+ It can also be used as a general-purpose statistical model library that extends [`equinox`](https://github.com/patrick-kidger/equinox) to allow for composable models that may have *coupled* parameters. It also implements some other nice features that are a bit awkward in `equinox` out of the box, such as easily switching model parameters between fixed and varying.
48
+
49
+ ## Installation
50
+
51
+ TODO
52
+
53
+ ## Usage
54
+
55
+ TODO
56
+
57
+ ## Citation
58
+
59
+ TODO
60
+
61
+ ## Help
62
+
63
+ TODO
64
+
65
+ ### TODO
66
+
67
+ - [x] Instead of replacing shared leaves with `0`, replace with some class/object instead
68
+ - [ ] Nicer `__repr__` for `ShareModule` that actually says the memory address
69
+ - [ ] Add memory address to the top of `print_model_tree`
70
+ - [ ] Support tuples, lists and dicts of models as attributes of models
71
+ - [ ] Handle an even number of modes
72
+ - [ ] Write better tests
73
+ - [ ] Rigorously type check the tests
@@ -0,0 +1,58 @@
1
+ <div id="top"></div>
2
+
3
+ <!-- PROJECT SHIELDS -->
4
+ <!-- [![PyPI Package][pypi-shield]][pypi-url] -->
5
+ <!-- [![JOSS][JOSS-shield]][JOSS-url] -->
6
+
7
+ <!-- ![GitHub Workflow Status](https://img.shields.io/github/actions/workflow/status/TomHilder/wakeflow/Tests.yml?label=tests&style=flat-square)
8
+ [![Docs][docs-status-shield]][docs-status-url] -->
9
+
10
+ <!-- PROJECT LOGO -->
11
+ <br />
12
+ <div align="center">
13
+ <a href="https://github.com/TomHilder/spectracles">
14
+ <img src="https://raw.githubusercontent.com/TomHilder/spectracles/main/logo.png" alt="spectracles" width="420">
15
+ </a>
16
+
17
+ <!-- <h3 align="center">Wakeflow</h3> -->
18
+
19
+ <p align="center">
20
+ Unified spectrospatial models for integral field spectroscopy in jax
21
+ </p>
22
+ </div>
23
+
24
+ <!-- <div align="center">
25
+ <img src="https://raw.githubusercontent.com/TomHilder/spectracles/main/logo.png" alt="spectracles" width="420"></img>
26
+ </div> -->
27
+
28
+ ## Glasses for your spectra
29
+
30
+ Spectracles is a Python library for inferring properties of IFU/IFS spectra as continuous functions of sky position.
31
+
32
+ It can also be used as a general-purpose statistical model library that extends [`equinox`](https://github.com/patrick-kidger/equinox) to allow for composable models that may have *coupled* parameters. It also implements some other nice features that are a bit awkward in `equinox` out of the box, such as easily switching model parameters between fixed and varying.
33
+
34
+ ## Installation
35
+
36
+ TODO
37
+
38
+ ## Usage
39
+
40
+ TODO
41
+
42
+ ## Citation
43
+
44
+ TODO
45
+
46
+ ## Help
47
+
48
+ TODO
49
+
50
+ ### TODO
51
+
52
+ - [x] Instead of replacing shared leaves with `0`, replace with some class/object instead
53
+ - [ ] Nicer `__repr__` for `ShareModule` that actually says the memory address
54
+ - [ ] Add memory address to the top of `print_model_tree`
55
+ - [ ] Support tuples, lists and dicts of models as attributes of models
56
+ - [ ] Handle an even number of modes
57
+ - [ ] Write better tests
58
+ - [ ] Rigorously type check the tests
@@ -0,0 +1,47 @@
1
[project]
name = "spectracles"
version = "0.5.0"
description = "Unified spectrospatial models: glasses for your spectra."
readme = "README.md"
authors = [
    { name = "Thomas Hilder", email = "Thomas.Hilder@monash.edu" },
    { name = "Andrew R. Casey", email = "Andrew.Casey@monash.edu" },
]
requires-python = ">=3.13"
dependencies = [
    # dill is imported by spectracles.model.io (save_model / load_model), which the
    # top-level package imports, so it is needed for `import spectracles` to work.
    "dill",
    "equinox>=0.13.0",
    "jax-finufft",
    "jax[cpu]>=0.7.0",
    "matplotlib>=3.10.5",
    "networkx>=3.5",
    "tqdm>=4.67.1",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build]
include = ["src/spectracles/py.typed"]
exclude = ["**/__pycache__/**"]

[tool.hatch.build.targets.sdist]
include = ["src/**", "README.md", "LICENSE"]

[tool.hatch.build.targets.wheel]
packages = ["src/spectracles"]

[tool.ruff]
line-length = 99  # Sorry PEP8 I'm a rebel

[tool.uv.sources]
jax-finufft = { git = "https://github.com/flatironinstitute/jax-finufft.git" }

[dependency-groups]
dev = [
    "mypy>=1.17.1",
    "pytest>=8.4.1",
    "pytest-cov>=6.2.1",
    "ruff>=0.12.8",
    "types-tqdm>=4.67.0.20250809",
]
@@ -0,0 +1,35 @@
1
+ from spectracles.model.data import SpatialDataGeneric, SpatialDataLVM
2
+ from spectracles.model.io import load_model, save_model
3
+ from spectracles.model.kernels import Kernel, Matern12, Matern32, Matern52, SquaredExponential
4
+ from spectracles.model.parameter import AnyParameter, ConstrainedParameter, Parameter, l_bounded
5
+ from spectracles.model.share_module import build_model
6
+ from spectracles.model.spatial import FourierGP, PerSpaxel, SpatialModel
7
+ from spectracles.model.spectral import Constant, Gaussian, SpectralSpatialModel
8
+ from spectracles.optimise.opt_frame import OptimiserFrame
9
+ from spectracles.optimise.opt_schedule import OptimiserSchedule, PhaseConfig
10
+
11
# Public API of the package. Every name below is imported above from its submodule and
# re-exported here; ``from spectracles import *`` binds exactly these names.
__all__ = [
    "FourierGP",
    "Gaussian",
    "SpatialDataGeneric",
    "SpatialDataLVM",
    "PerSpaxel",
    "Constant",
    "Kernel",
    "Matern12",
    "Matern32",
    "Matern52",
    "SquaredExponential",
    "build_model",
    "SpatialModel",
    "SpectralSpatialModel",
    "Parameter",
    "ConstrainedParameter",
    "AnyParameter",
    "l_bounded",
    "OptimiserFrame",
    "save_model",
    "load_model",
    "PhaseConfig",
    "OptimiserSchedule",
]
@@ -0,0 +1 @@
1
+ """model - subpackage for modelling framework built on jax and equinox."""
@@ -0,0 +1,27 @@
1
+ """data.py - Data structures used as arguments for model evaluations/predictions."""
2
+
3
+ import jax.numpy as jnp
4
+ from equinox import Module, field
5
+ from jaxtyping import Array
6
+
7
+
8
def convert_to_flat_array(array: Array) -> Array:
    """
    Coerce ``array`` to a JAX array and flatten it to one dimension.

    Used as the field converter for the spatial data classes, so scalars, nested
    sequences and N-d arrays are all stored as 1D arrays (a scalar becomes shape (1,)).
    """
    as_jax = jnp.asarray(array)
    return jnp.ravel(as_jax)
10
+
11
+
12
class SpatialDataGeneric(Module):
    """
    Generic spatial inputs passed to model evaluations/predictions.

    Every field goes through ``convert_to_flat_array`` on construction, so any
    array-like input is stored as a flattened 1D JAX array.
    """

    # x position of each sample (presumably a sky-plane coordinate — confirm units with callers)
    x: Array = field(converter=convert_to_flat_array)
    # y position of each sample (same caveat as x)
    y: Array = field(converter=convert_to_flat_array)
    # Integer index of each sample — presumably into per-spaxel arrays; TODO confirm
    idx: Array = field(converter=convert_to_flat_array)
16
+
17
+
18
class SpatialDataLVM(Module):
    """
    Spatial inputs for LVM data, extending the generic fields with tile and IFU indices.

    Every field goes through ``convert_to_flat_array`` on construction, so any
    array-like input is stored as a flattened 1D JAX array.
    """

    # x position of each sample (presumably a sky-plane coordinate — confirm units with callers)
    x: Array = field(converter=convert_to_flat_array)
    # y position of each sample (same caveat as x)
    y: Array = field(converter=convert_to_flat_array)
    # Integer index of each sample — presumably into per-spaxel arrays; TODO confirm
    idx: Array = field(converter=convert_to_flat_array)
    # Presumably the index of the tile each sample was observed in — confirm
    tile_idx: Array = field(converter=convert_to_flat_array)
    # Presumably the index of the IFU each sample comes from — confirm
    ifu_idx: Array = field(converter=convert_to_flat_array)
24
+
25
+
26
# TODO: This is bad. We want users to be able to write their own SpatialDataFoo class,
# but with the current setup the typing doesn't strictly allow this. Really this should
# be a Protocol or subclassing situation. Not worth refactoring right now.
# Type alias accepted wherever model code takes spatial input data.
SpatialData = SpatialDataGeneric | SpatialDataLVM
@@ -0,0 +1,84 @@
1
+ """graph.py - Graph utilities for the modelling framework."""
2
+
3
+ from contextlib import contextmanager
4
+
5
+ import matplotlib as mpl
6
+ from networkx import DiGraph
7
+
8
# Default styling keywords for drawing model graphs: large white nodes with black
# outlines and black labels. The keys match networkx drawing-function arguments
# (presumably passed to networkx.draw — confirm at call sites).
DEFAULT_NX_KWDS = dict(
    node_size=8_000,
    node_color="white",
    edgecolors="black",
    linewidths=2,
    arrowsize=20,
    font_size=8,
    font_color="black",
)
+ }
17
+
18
+
19
@contextmanager
def temporarily_disable_tex():
    """
    Context manager that turns off matplotlib's ``text.usetex`` for the enclosed block.

    The previous ``text.usetex`` value is restored on exit, including when the block
    raises. No other rcParams are touched.
    """
    params = mpl.rcParams
    saved_usetex = params["text.usetex"]
    params["text.usetex"] = False
    try:
        yield
    finally:
        # Restore whatever the user had configured before entering
        params["text.usetex"] = saved_usetex
27
+
28
+
29
def print_graph(graph: DiGraph, root_id: int, indent: str = "", is_last: bool = True) -> None:
    """
    Print the subtree of ``graph`` rooted at ``root_id`` as a text tree.

    Each node must carry ``"name"`` and ``"type"`` attributes. A node whose name is
    ``None`` (a root module with no parent attribute name) is shown by its type only.
    Every other node is shown as ``"name (Type)"``. A node's children are its successors
    (the targets of its outgoing edges), printed in ``graph.successors`` order.

    Args:
        graph (DiGraph): Directed graph with edges pointing from parent to child.
        root_id (int): Node at which to start printing.
        indent (str): Prefix accumulated from ancestor levels; callers normally leave the
            default.
        is_last (bool): Whether this node is the last child of its parent, which selects
            the connector drawn before it.
    """
    node_data = graph.nodes[root_id]
    name = node_data["name"]
    node_type = node_data["type"]
    # Root module without parent attribute name -> "Type"; otherwise "name (Type)"
    display_text = node_type if name is None else f"{name} ({node_type})"
    print(f"{indent}{'└── ' if is_last else '├── '}{display_text}")
    # Children are the nodes this node points to. successors() yields them in the same
    # order as filtering graph.edges() by source, but costs O(out-degree) per node
    # instead of a scan over every edge in the graph.
    children = list(graph.successors(root_id))
    # The continuation prefix is as wide as the 4-character connector so levels align
    new_indent = indent + ("    " if is_last else "│   ")
    last_index = len(children) - 1
    for i, child_id in enumerate(children):
        print_graph(graph, child_id, new_indent, i == last_index)
53
+
54
+
55
def layered_hierarchy_pos(G, root, total_width=1.0, vert_gap=0.2):
    """
    Compute layered, tree-style plot positions for every node reachable from ``root``.

    Depths come from a breadth-first search over ``G.successors``. A node reachable
    along several paths is placed at the depth where it is first discovered, which is
    the shallowest one. Within a depth, nodes are spread evenly across ``total_width``
    in discovery order. Depths stack downwards ``vert_gap`` apart, with ``root`` at
    ``y = 0``.

    NOTE: originally an AI generated function.

    Args:
        G: Graph exposing ``successors(node)`` (e.g. a networkx ``DiGraph``).
        root: Node placed on the top layer.
        total_width: Horizontal extent over which each layer is spread.
        vert_gap: Vertical distance between consecutive layers.

    Returns:
        dict: Mapping node -> ``(x, y)``, inserted layer by layer in discovery order.
    """
    # Breadth-first search, expanding one whole layer at a time
    seen = {root}
    layers = [[root]]
    while True:
        next_layer = []
        for parent in layers[-1]:
            for child in G.successors(parent):
                if child not in seen:
                    seen.add(child)
                    next_layer.append(child)
        if not next_layer:
            break
        layers.append(next_layer)

    # Spread each layer evenly across the width, leaving one gap at each edge
    positions = {}
    for depth, layer in enumerate(layers):
        spacing = total_width / (len(layer) + 1)
        for slot, node in enumerate(layer, start=1):
            positions[node] = (slot * spacing, -depth * vert_gap)
    return positions
@@ -0,0 +1,78 @@
1
+ """io.py - Model object serialisation and deserialisation."""
2
+
3
+ from pathlib import Path
4
+ from warnings import catch_warnings, filterwarnings
5
+
6
+ from dill import PicklingWarning, dump, load # type: ignore[import]
7
+
8
+ from spectracles.model.share_module import ShareModule
9
+
10
# Suffix forced onto every path given to save_model / load_model (via Path.with_suffix),
# so model files always carry this extension regardless of what the caller passed.
MODELFILE_EXT = ".model"
11
+
12
+
13
def save_model(model: ShareModule, file: Path, overwrite: bool = False, **dump_kwargs):
    """
    Save a model to a file.

    The model is serialised with dill into memory first. The file is only written once
    serialisation has succeeded, so a failed save never truncates or clobbers an
    existing model file.

    Args:
        model (ShareModule): The model to save.
        file (Path): The file to save the model to. A ``str`` path is also accepted. The
            suffix is always replaced with ``MODELFILE_EXT``.
        overwrite (bool): Whether to overwrite the file if it exists. Defaults to False.
        dump_kwargs: Additional arguments to pass to dill.dump.

    Raises:
        FileExistsError: If the file already exists and overwrite is False.
        TypeError: If the model is not of type ShareModule.
        Exception: If dill emits a PicklingWarning while saving. All warnings are
            escalated to errors during serialisation, so any other warning propagates
            as its own exception type.
    """
    # Absolute import: this resolves to the standard-library io module, not this file
    from io import BytesIO

    # Ensure always the same file extension (Path() also accepts plain str paths)
    file = Path(file).with_suffix(MODELFILE_EXT)
    # Check that the file doesn't exist already
    if file.exists() and not overwrite:
        raise FileExistsError("File already exists. Overwrite with overwrite=True.")
    # Check that model has the right type
    if not isinstance(model, ShareModule):
        raise TypeError(
            "model must be type ShareModule. Saving sub-components of a model is not currently supported. Saving a model not instantiated via build_model will never be supported."
        )
    buffer = BytesIO()
    # We want dill warnings to be exceptions
    with catch_warnings():
        filterwarnings("error")
        try:
            dump(model, buffer, **dump_kwargs)
        except PicklingWarning as err:
            raise Exception(
                "Above warning raised by dill when saving. Likely, you need to move your model class into its own file and import it."
            ) from err
    # Only touch the file once serialisation has fully succeeded
    file.write_bytes(buffer.getvalue())
50
+
51
+
52
def load_model(file: Path, **load_kwargs) -> ShareModule:
    """
    Load a model from a file.

    Warning:
        Deserialisation uses dill (a pickle extension), which can execute arbitrary code
        embedded in the file. Only load model files from sources you trust.

    Args:
        file (Path): The file to load the model from. A ``str`` path is also accepted.
            The suffix is always replaced with ``MODELFILE_EXT``.
        load_kwargs: Additional arguments to pass to dill.load.

    Returns:
        ShareModule: The loaded model.

    Raises:
        FileNotFoundError: If the file does not exist.
        TypeError: If the loaded model is not of type ShareModule.
    """
    # Ensure always the same file extension (Path() also accepts plain str paths)
    file = Path(file).with_suffix(MODELFILE_EXT)
    # Check that the file exists
    if not file.exists():
        raise FileNotFoundError(f"File {file} does not exist.")
    # Open file in read and bytes mode
    with open(file, "rb") as f:
        # SECURITY: unpickling untrusted data can run arbitrary code; see docstring
        model = load(f, **load_kwargs)
    # Check that the model is of the right type
    if not isinstance(model, ShareModule):
        raise TypeError("Loaded from file successfully, but model is not of type ShareModule.")
    return model
@@ -0,0 +1,90 @@
1
+ """kernels.py - Kernel classes that implement various covariance functions expressed by their power spectral density or 'feature weights' for use with Fourier-accelerated GP models."""
2
+
3
+ from abc import abstractmethod
4
+
5
+ import jax.numpy as jnp
6
+ from equinox import Module
7
+ from jaxtyping import Array, ArrayLike
8
+
9
+ from spectracles.model.parameter import Parameter
10
+
11
+
12
def normalise_fw(fw: ArrayLike) -> Array:
    """
    Rescale feature weights so that their total power, sum(|fw|**2), equals 2.

    The extra factor of sqrt(2) compensates for the halving in total power caused by
    forcing the Fourier coefficients to be real through conjugate symmetry.
    """
    total_power = jnp.sum(jnp.abs(fw) ** 2)
    scaled = jnp.sqrt(2) * fw
    return scaled / jnp.sqrt(total_power)
20
+
21
+
22
def matern_kernel_fw_nd(
    freqs: ArrayLike,
    length: ArrayLike,
    var: ArrayLike,
    nu: float,
    n: int,
) -> Array:
    """
    Square root of the PSD of the Matern kernel in n dimensions.

    The unnormalised shape is ``(1 + (freqs * length)**2) ** (-(nu + n / 2) / 2)``.
    It is passed through ``normalise_fw`` and then scaled by ``sqrt(var)``, so the
    squared weights sum to ``2 * var``.

    Args:
        freqs: Frequencies at which to evaluate the weights.
        length: Kernel length scale.
        var: Kernel variance.
        nu: Matern smoothness parameter (e.g. 0.5, 1.5, 2.5).
        n: Number of input dimensions.
    """
    exponent = -0.5 * (nu + n / 2)
    base = 1 + (freqs * length) ** 2
    return jnp.sqrt(var) * normalise_fw(base**exponent)
32
+
33
+
34
class Kernel(Module):
    """
    Base class for the covariance kernels in this module.

    A kernel is expressed through its feature weights: ``feature_weights(freqs)`` maps
    an array of frequencies to the weight of each mode. The concrete kernels here build
    these weights with ``normalise_fw`` and scale them by ``sqrt(variance.val)``.
    """

    # All kernels should have a length scale and variance
    length_scale: Parameter
    variance: Parameter

    @abstractmethod
    def feature_weights(self, freqs: Array) -> Array:
        """Return the feature weight of each mode at ``freqs``; implemented by subclasses."""
        pass
42
+
43
+
44
class Matern12(Kernel):
    """
    Matern kernel with smoothness nu = 1/2, evaluated in two dimensions.

    Args:
        length_scale (Parameter): Length scale; its ``.val`` is used.
        variance (Parameter): Variance; its ``.val`` is used.
    """

    length_scale: Parameter
    variance: Parameter

    def __init__(self, length_scale: Parameter, variance: Parameter):
        self.length_scale = length_scale
        self.variance = variance

    def feature_weights(self, freqs: Array) -> Array:
        """Return the normalised Matern-1/2 feature weights at ``freqs``."""
        return matern_kernel_fw_nd(
            freqs,
            length=self.length_scale.val,
            var=self.variance.val,
            nu=0.5,
            n=2,
        )
54
+
55
+
56
class Matern32(Kernel):
    """
    Matern kernel with smoothness nu = 3/2, evaluated in two dimensions.

    Args:
        length_scale (Parameter): Length scale; its ``.val`` is used.
        variance (Parameter): Variance; its ``.val`` is used.
    """

    length_scale: Parameter
    variance: Parameter

    def __init__(self, length_scale: Parameter, variance: Parameter):
        self.length_scale = length_scale
        self.variance = variance

    def feature_weights(self, freqs: Array) -> Array:
        """Return the normalised Matern-3/2 feature weights at ``freqs``."""
        return matern_kernel_fw_nd(
            freqs,
            length=self.length_scale.val,
            var=self.variance.val,
            nu=1.5,
            n=2,
        )
66
+
67
+
68
class Matern52(Kernel):
    """
    Matern kernel with smoothness nu = 5/2, evaluated in two dimensions.

    Args:
        length_scale (Parameter): Length scale; its ``.val`` is used.
        variance (Parameter): Variance; its ``.val`` is used.
    """

    length_scale: Parameter
    variance: Parameter

    def __init__(self, length_scale: Parameter, variance: Parameter):
        self.length_scale = length_scale
        self.variance = variance

    def feature_weights(self, freqs: Array) -> Array:
        """Return the normalised Matern-5/2 feature weights at ``freqs``."""
        return matern_kernel_fw_nd(
            freqs,
            length=self.length_scale.val,
            var=self.variance.val,
            nu=2.5,
            n=2,
        )
78
+
79
+
80
class SquaredExponential(Kernel):
    """
    Squared-exponential kernel.

    The feature weights have shape ``exp(-0.25 * freqs**2 * length_scale**2)``. They are
    normalised with ``normalise_fw`` and then scaled by ``sqrt(variance)``.

    Args:
        length_scale (Parameter): Length scale; its ``.val`` is used.
        variance (Parameter): Variance; its ``.val`` is used.
    """

    length_scale: Parameter
    variance: Parameter

    def __init__(self, length_scale: Parameter, variance: Parameter):
        self.length_scale = length_scale
        self.variance = variance

    def feature_weights(self, freqs: Array) -> Array:
        """Return the normalised squared-exponential feature weights at ``freqs``."""
        ell = self.length_scale.val
        # NOTE(review): the "+ 1e-4" sits inside the exponential, so it only multiplies
        # every weight by the constant exp(1e-4). That factor cancels in normalise_fw (up
        # to float rounding), making the term a no-op. If a floor on the weights was
        # intended, it would have to be added outside jnp.exp — confirm intent.
        log_fw = -0.25 * freqs**2 * ell**2 + 1e-4
        return jnp.sqrt(self.variance.val) * normalise_fw(jnp.exp(log_fw))
+ return jnp.sqrt(self.variance.val) * normalise_fw(fw)