spectracles 0.5.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spectracles-0.5.0/LICENSE +21 -0
- spectracles-0.5.0/PKG-INFO +73 -0
- spectracles-0.5.0/README.md +58 -0
- spectracles-0.5.0/pyproject.toml +47 -0
- spectracles-0.5.0/src/spectracles/__init__.py +35 -0
- spectracles-0.5.0/src/spectracles/model/__init__.py +1 -0
- spectracles-0.5.0/src/spectracles/model/data.py +27 -0
- spectracles-0.5.0/src/spectracles/model/graph.py +84 -0
- spectracles-0.5.0/src/spectracles/model/io.py +78 -0
- spectracles-0.5.0/src/spectracles/model/kernels.py +90 -0
- spectracles-0.5.0/src/spectracles/model/parameter.py +226 -0
- spectracles-0.5.0/src/spectracles/model/share_module.py +358 -0
- spectracles-0.5.0/src/spectracles/model/spatial.py +148 -0
- spectracles-0.5.0/src/spectracles/model/spectral.py +43 -0
- spectracles-0.5.0/src/spectracles/optimise/__init__.py +1 -0
- spectracles-0.5.0/src/spectracles/optimise/opt_frame.py +175 -0
- spectracles-0.5.0/src/spectracles/optimise/opt_schedule.py +359 -0
- spectracles-0.5.0/src/spectracles/py.typed +0 -0
- spectracles-0.5.0/src/spectracles/tree/__init__.py +1 -0
- spectracles-0.5.0/src/spectracles/tree/path_utils.py +90 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 Thomas Hilder
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: spectracles
|
|
3
|
+
Version: 0.5.0
|
|
4
|
+
Summary: Unified spectrospatial models: glasses for your spectra.
|
|
5
|
+
Author-email: Thomas Hilder <Thomas.Hilder@monash.edu>, "Andrew R. Casey" <Andrew.Casey@monash.edu>
|
|
6
|
+
License-File: LICENSE
|
|
7
|
+
Requires-Python: >=3.13
|
|
8
|
+
Requires-Dist: equinox>=0.13.0
|
|
9
|
+
Requires-Dist: jax-finufft
|
|
10
|
+
Requires-Dist: jax[cpu]>=0.7.0
|
|
11
|
+
Requires-Dist: matplotlib>=3.10.5
|
|
12
|
+
Requires-Dist: networkx>=3.5
|
|
13
|
+
Requires-Dist: tqdm>=4.67.1
|
|
14
|
+
Description-Content-Type: text/markdown
|
|
15
|
+
|
|
16
|
+
<div id="top"></div>
|
|
17
|
+
|
|
18
|
+
<!-- PROJECT SHIELDS -->
|
|
19
|
+
<!-- [![PyPI Package][pypi-shield]][pypi-url] -->
|
|
20
|
+
<!-- [![JOSS][JOSS-shield]][JOSS-url] -->
|
|
21
|
+
|
|
22
|
+
<!-- 
|
|
23
|
+
[![Docs][docs-status-shield]][docs-status-url] -->
|
|
24
|
+
|
|
25
|
+
<!-- PROJECT LOGO -->
|
|
26
|
+
<br />
|
|
27
|
+
<div align="center">
|
|
28
|
+
<a href="https://github.com/TomHilder/spectracles">
|
|
29
|
+
<img src="https://raw.githubusercontent.com/TomHilder/spectracles/main/logo.png" alt="spectracles" width="420">
|
|
30
|
+
</a>
|
|
31
|
+
|
|
32
|
+
<!-- <h3 align="center">Spectracles</h3> -->
|
|
33
|
+
|
|
34
|
+
<p align="center">
|
|
35
|
+
Unified spectrospatial models for integral field spectroscopy in jax
|
|
36
|
+
</p>
|
|
37
|
+
</div>
|
|
38
|
+
|
|
39
|
+
<!-- <div align="center">
|
|
40
|
+
<img src="https://raw.githubusercontent.com/TomHilder/spectracles/main/logo.png" alt="spectracles" width="420"></img>
|
|
41
|
+
</div> -->
|
|
42
|
+
|
|
43
|
+
## Glasses for your spectra
|
|
44
|
+
|
|
45
|
+
Spectracles is a Python library for inferring properties of IFU/IFS spectra as continuous functions of sky position.
|
|
46
|
+
|
|
47
|
+
It can also be used as a general-purpose statistical model library that extends [`equinox`](https://github.com/patrick-kidger/equinox) to allow for composable models that may have *coupled* parameters. It also implements some other nice features that are a bit awkward in `equinox` out of the box, like easily switching model parameters between fixed and varying.
|
|
48
|
+
|
|
49
|
+
## Installation
|
|
50
|
+
|
|
51
|
+
TODO
|
|
52
|
+
|
|
53
|
+
## Usage
|
|
54
|
+
|
|
55
|
+
TODO
|
|
56
|
+
|
|
57
|
+
## Citation
|
|
58
|
+
|
|
59
|
+
TODO
|
|
60
|
+
|
|
61
|
+
## Help
|
|
62
|
+
|
|
63
|
+
TODO
|
|
64
|
+
|
|
65
|
+
### TODO
|
|
66
|
+
|
|
67
|
+
- [x] Instead of replacing shared leaves with `0`, replace with some class/object instead
|
|
68
|
+
- [ ] Nicer `__repr__` for `ShareModule` that actually says the memory address
|
|
69
|
+
- [ ] Add memory address to the top of `print_model_tree`
|
|
70
|
+
- [ ] Support tuples, lists and dicts of models as attributes of models
|
|
71
|
+
- [ ] Handle even numbers of modes
|
|
72
|
+
- [ ] Write better tests
|
|
73
|
+
- [ ] Rigorously type check the tests
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
<div id="top"></div>
|
|
2
|
+
|
|
3
|
+
<!-- PROJECT SHIELDS -->
|
|
4
|
+
<!-- [![PyPI Package][pypi-shield]][pypi-url] -->
|
|
5
|
+
<!-- [![JOSS][JOSS-shield]][JOSS-url] -->
|
|
6
|
+
|
|
7
|
+
<!-- 
|
|
8
|
+
[![Docs][docs-status-shield]][docs-status-url] -->
|
|
9
|
+
|
|
10
|
+
<!-- PROJECT LOGO -->
|
|
11
|
+
<br />
|
|
12
|
+
<div align="center">
|
|
13
|
+
<a href="https://github.com/TomHilder/spectracles">
|
|
14
|
+
<img src="https://raw.githubusercontent.com/TomHilder/spectracles/main/logo.png" alt="spectracles" width="420">
|
|
15
|
+
</a>
|
|
16
|
+
|
|
17
|
+
<!-- <h3 align="center">Spectracles</h3> -->
|
|
18
|
+
|
|
19
|
+
<p align="center">
|
|
20
|
+
Unified spectrospatial models for integral field spectroscopy in jax
|
|
21
|
+
</p>
|
|
22
|
+
</div>
|
|
23
|
+
|
|
24
|
+
<!-- <div align="center">
|
|
25
|
+
<img src="https://raw.githubusercontent.com/TomHilder/spectracles/main/logo.png" alt="spectracles" width="420"></img>
|
|
26
|
+
</div> -->
|
|
27
|
+
|
|
28
|
+
## Glasses for your spectra
|
|
29
|
+
|
|
30
|
+
Spectracles is a Python library for inferring properties of IFU/IFS spectra as continuous functions of sky position.
|
|
31
|
+
|
|
32
|
+
It can also be used as a general-purpose statistical model library that extends [`equinox`](https://github.com/patrick-kidger/equinox) to allow for composable models that may have *coupled* parameters. It also implements some other nice features that are a bit awkward in `equinox` out of the box, like easily switching model parameters between fixed and varying.
|
|
33
|
+
|
|
34
|
+
## Installation
|
|
35
|
+
|
|
36
|
+
TODO
|
|
37
|
+
|
|
38
|
+
## Usage
|
|
39
|
+
|
|
40
|
+
TODO
|
|
41
|
+
|
|
42
|
+
## Citation
|
|
43
|
+
|
|
44
|
+
TODO
|
|
45
|
+
|
|
46
|
+
## Help
|
|
47
|
+
|
|
48
|
+
TODO
|
|
49
|
+
|
|
50
|
+
### TODO
|
|
51
|
+
|
|
52
|
+
- [x] Instead of replacing shared leaves with `0`, replace with some class/object instead
|
|
53
|
+
- [ ] Nicer `__repr__` for `ShareModule` that actually says the memory address
|
|
54
|
+
- [ ] Add memory address to the top of `print_model_tree`
|
|
55
|
+
- [ ] Support tuples, lists and dicts of models as attributes of models
|
|
56
|
+
- [ ] Handle even numbers of modes
|
|
57
|
+
- [ ] Write better tests
|
|
58
|
+
- [ ] Rigorously type check the tests
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
[project]
name = "spectracles"
version = "0.5.0"
description = "Unified spectrospatial models: glasses for your spectra."
readme = "README.md"
authors = [
    { name = "Thomas Hilder", email = "Thomas.Hilder@monash.edu" },
    { name = "Andrew R. Casey", email = "Andrew.Casey@monash.edu" },
]
requires-python = ">=3.13"
dependencies = [
    # dill is imported by spectracles.model.io (save_model/load_model), which the package
    # __init__ imports eagerly, so it is a hard runtime requirement of `import spectracles`.
    "dill",
    "equinox>=0.13.0",
    "jax-finufft",
    "jax[cpu]>=0.7.0",
    # jaxtyping is imported directly by model/data.py and model/kernels.py; declare it
    # explicitly rather than relying on another dependency to pull it in.
    "jaxtyping",
    "matplotlib>=3.10.5",
    "networkx>=3.5",
    "tqdm>=4.67.1",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build]
include = ["src/spectracles/py.typed"]
exclude = ["**/__pycache__/**"]

[tool.hatch.build.targets.sdist]
include = ["src/**", "README.md", "LICENSE"]

[tool.hatch.build.targets.wheel]
packages = ["src/spectracles"]

[tool.ruff]
line-length = 99 # Sorry PEP8 I'm a rebel

[tool.uv.sources]
jax-finufft = { git = "https://github.com/flatironinstitute/jax-finufft.git" }

[dependency-groups]
dev = [
    "mypy>=1.17.1",
    "pytest>=8.4.1",
    "pytest-cov>=6.2.1",
    "ruff>=0.12.8",
    "types-tqdm>=4.67.0.20250809",
]
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
from spectracles.model.data import SpatialDataGeneric, SpatialDataLVM
|
|
2
|
+
from spectracles.model.io import load_model, save_model
|
|
3
|
+
from spectracles.model.kernels import Kernel, Matern12, Matern32, Matern52, SquaredExponential
|
|
4
|
+
from spectracles.model.parameter import AnyParameter, ConstrainedParameter, Parameter, l_bounded
|
|
5
|
+
from spectracles.model.share_module import build_model
|
|
6
|
+
from spectracles.model.spatial import FourierGP, PerSpaxel, SpatialModel
|
|
7
|
+
from spectracles.model.spectral import Constant, Gaussian, SpectralSpatialModel
|
|
8
|
+
from spectracles.optimise.opt_frame import OptimiserFrame
|
|
9
|
+
from spectracles.optimise.opt_schedule import OptimiserSchedule, PhaseConfig
|
|
10
|
+
|
|
11
|
+
# Public API re-exported at the package top level (what `from spectracles import *` binds).
# Kept as a plain list literal so static analysers and type checkers can read it.
__all__ = [
    "FourierGP",  # spectracles.model.spatial
    "Gaussian",  # spectracles.model.spectral
    "SpatialDataGeneric",  # spectracles.model.data
    "SpatialDataLVM",  # spectracles.model.data
    "PerSpaxel",  # spectracles.model.spatial
    "Constant",  # spectracles.model.spectral
    "Kernel",  # spectracles.model.kernels
    "Matern12",  # spectracles.model.kernels
    "Matern32",  # spectracles.model.kernels
    "Matern52",  # spectracles.model.kernels
    "SquaredExponential",  # spectracles.model.kernels
    "build_model",  # spectracles.model.share_module
    "SpatialModel",  # spectracles.model.spatial
    "SpectralSpatialModel",  # spectracles.model.spectral
    "Parameter",  # spectracles.model.parameter
    "ConstrainedParameter",  # spectracles.model.parameter
    "AnyParameter",  # spectracles.model.parameter
    "l_bounded",  # spectracles.model.parameter
    "OptimiserFrame",  # spectracles.optimise.opt_frame
    "save_model",  # spectracles.model.io
    "load_model",  # spectracles.model.io
    "PhaseConfig",  # spectracles.optimise.opt_schedule
    "OptimiserSchedule",  # spectracles.optimise.opt_schedule
]
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""model - subpackage for modelling framework built on jax and equinox."""
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
"""data.py - Data structures used as arguments for model evaluations/predictions."""
|
|
2
|
+
|
|
3
|
+
import jax.numpy as jnp
|
|
4
|
+
from equinox import Module, field
|
|
5
|
+
from jaxtyping import Array
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def convert_to_flat_array(array: Array) -> Array:
    """Coerce ``array`` to a JAX array and return it as a 1-D array.

    Used as the ``equinox.field`` converter for the spatial data containers, so each field is
    stored flat regardless of the shape it was constructed with. Scalars become shape ``(1,)``.
    """
    as_jax = jnp.asarray(array)
    return jnp.ravel(as_jax)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class SpatialDataGeneric(Module):
    """Spatial data passed as an argument to model evaluations/predictions.

    Every field is run through ``convert_to_flat_array`` on construction, so each is stored
    as a flat 1-D JAX array whatever shape it was given in.
    """

    # x position of each sample (presumably a sky coordinate — TODO confirm units).
    x: Array = field(converter=convert_to_flat_array)
    # y position of each sample (presumably a sky coordinate — TODO confirm units).
    y: Array = field(converter=convert_to_flat_array)
    # Per-sample integer index; presumably identifies the spaxel each sample belongs to —
    # TODO confirm against the spatial models that consume it.
    idx: Array = field(converter=convert_to_flat_array)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class SpatialDataLVM(Module):
    """Spatial data container with additional tile and IFU index arrays.

    Named for LVM data (presumably the Local Volume Mapper survey — TODO confirm). As with the
    generic container, every field is flattened to a 1-D JAX array by ``convert_to_flat_array``.
    """

    # x position of each sample (presumably a sky coordinate — TODO confirm units).
    x: Array = field(converter=convert_to_flat_array)
    # y position of each sample (presumably a sky coordinate — TODO confirm units).
    y: Array = field(converter=convert_to_flat_array)
    # Per-sample index, with the same role as SpatialDataGeneric.idx.
    idx: Array = field(converter=convert_to_flat_array)
    # Presumably the observing tile each sample belongs to — TODO confirm against callers.
    tile_idx: Array = field(converter=convert_to_flat_array)
    # Presumably the IFU each sample was taken with — TODO confirm against callers.
    ifu_idx: Array = field(converter=convert_to_flat_array)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# TODO: This is bad. We want users to be able to write their own SpatialDataFoo class, but with
# the current setup the typing doesn't strictly allow this. Really this should be a Protocol or
# subclassing situation. Not worth refactoring right now.
# Type alias: union of the two spatial data containers defined in this module.
SpatialData = SpatialDataGeneric | SpatialDataLVM
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
"""graph.py - Graph utilities for the modelling framework."""
|
|
2
|
+
|
|
3
|
+
from contextlib import contextmanager
|
|
4
|
+
|
|
5
|
+
import matplotlib as mpl
|
|
6
|
+
from networkx import DiGraph
|
|
7
|
+
|
|
8
|
+
# Default keyword arguments for drawing model graphs (names follow networkx's drawing API):
# large white nodes with black outlines, sizeable arrows, and small black labels.
DEFAULT_NX_KWDS = dict(
    node_size=8_000,
    node_color="white",
    edgecolors="black",
    linewidths=2,
    arrowsize=20,
    font_size=8,
    font_color="black",
)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@contextmanager
def temporarily_disable_tex():
    """Switch off matplotlib's ``text.usetex`` for the duration of a ``with`` block.

    Only the ``text.usetex`` rcParam is touched. Its previous value is put back on exit,
    whether the body finishes normally or raises.
    """
    previous_usetex = mpl.rcParams["text.usetex"]
    mpl.rcParams["text.usetex"] = False
    try:
        yield
    finally:
        # Restore the caller's setting even if the body raised.
        mpl.rcParams["text.usetex"] = previous_usetex
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def print_graph(graph: DiGraph, root_id: int, indent: str = "", is_last: bool = True) -> None:
    """Print the subtree of ``graph`` rooted at ``root_id`` as an ASCII tree.

    Each node must carry ``"name"`` and ``"type"`` attributes. A node whose name is ``None``
    (the root module, which has no parent attribute name) is shown by its type alone; every
    other node is shown as ``"name (Type)"``. Children are the node's successors, visited in
    the graph's adjacency (insertion) order.

    Args:
        graph: Directed graph whose edges point from a parent node to its children.
        root_id: Node at which to start printing.
        indent: Prefix accumulated from ancestor levels (built up by the recursion).
        is_last: Whether this node is the last child of its parent; selects the branch glyph.
    """
    node_data = graph.nodes[root_id]
    name = node_data["name"]
    node_type = node_data["type"]
    # Root module (no parent attribute name) shows only its type; others "name (Type)".
    display_text = node_type if name is None else f"{name} ({node_type})"
    print(f"{indent}{'└── ' if is_last else '├── '}{display_text}")

    # Children are the nodes this node points *to*. successors() yields them in the same order
    # as filtering graph.edges() on the source node, but in O(out-degree) rather than O(E).
    children = list(graph.successors(root_id))
    # Continuation indent is as wide as the 4-character branch glyphs above.
    new_indent = indent + ("    " if is_last else "│   ")
    for i, child_id in enumerate(children):
        print_graph(graph, child_id, new_indent, i == len(children) - 1)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def layered_hierarchy_pos(G, root, total_width=1.0, vert_gap=0.2):
    """Compute top-down layered (x, y) positions for drawing a hierarchy.

    NOTE: AI generated function.

    Nodes are grouped into levels by breadth-first search from ``root``: a node's level is the
    length of the shortest path to it from ``root``. Each level's nodes are spaced evenly
    across ``total_width`` in the order BFS first reaches them, and each level sits
    ``vert_gap`` below the one above. Nodes not reachable from ``root`` get no position.

    Args:
        G: Graph exposing ``successors(node)`` (e.g. a networkx ``DiGraph``).
        root: Node placed alone on the top level.
        total_width: Horizontal extent across which each level is spread.
        vert_gap: Vertical spacing between consecutive levels.

    Returns:
        dict mapping each reachable node to ``(x, y)``; the root is at ``y = 0`` and deeper
        levels have increasingly negative ``y``.
    """
    from collections import deque

    layers = []
    # Mark nodes when they are first queued, so each node is queued exactly once; BFS order
    # guarantees this first sighting is at the node's shallowest depth.
    seen = {root}
    frontier = deque([(root, 0)])
    while frontier:
        node, depth = frontier.popleft()
        # BFS pops depths in non-decreasing order, one level at a time.
        if depth == len(layers):
            layers.append([])
        layers[depth].append(node)
        for child in G.successors(node):
            if child not in seen:
                seen.add(child)
                frontier.append((child, depth + 1))

    positions = {}
    for depth, layer in enumerate(layers):
        spacing = total_width / (len(layer) + 1)
        for slot, node in enumerate(layer, start=1):
            positions[node] = (slot * spacing, -depth * vert_gap)
    return positions
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
"""io.py - Model object serialisation and deserialisation."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from warnings import catch_warnings, filterwarnings
|
|
5
|
+
|
|
6
|
+
from dill import PicklingWarning, dump, load # type: ignore[import]
|
|
7
|
+
|
|
8
|
+
from spectracles.model.share_module import ShareModule
|
|
9
|
+
|
|
10
|
+
# File extension forced onto every saved/loaded model path (with_suffix replaces any other).
MODELFILE_EXT = ".model"
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def save_model(model: ShareModule, file: Path, overwrite: bool = False, **dump_kwargs):
    """
    Save a model to a file.

    The model is first serialised with dill into an in-memory buffer. It is written to disk
    only once serialisation has succeeded, so a failed save never leaves a truncated file
    behind and, with ``overwrite=True``, never destroys the previously saved file.

    Args:
        model (ShareModule): The model to save.
        file (Path): The file to save the model to. A ``str`` path is also accepted. Any
            existing suffix is replaced by ``MODELFILE_EXT`` (".model").
        overwrite (bool): Whether to overwrite the file if it exists. Defaults to False.
        dump_kwargs: Additional arguments to pass to dill.dump.

    Raises:
        FileExistsError: If the file already exists and overwrite is False.
        TypeError: If the model is not of type ShareModule.
        Exception: If dill emits a PicklingWarning while saving. All warnings are promoted to
            errors during serialisation, so other warning categories propagate as exceptions
            of their own type.
    """
    from io import BytesIO

    # Ensure always the same file extension (Path() also accepts str input)
    file = Path(file).with_suffix(MODELFILE_EXT)
    # Check that the file doesn't exist already
    if file.exists() and not overwrite:
        raise FileExistsError("File already exists. Overwrite with overwrite=True.")
    # Check that model has the right type
    if not isinstance(model, ShareModule):
        raise TypeError(
            "model must be type ShareModule. Saving sub-components of a model is not currently supported. Saving a model not instantiated via build_model will never be supported."
        )
    # Serialise into memory first so the target file is only touched on success.
    buffer = BytesIO()
    # We want dill warnings to be exceptions
    with catch_warnings():
        filterwarnings("error")
        try:
            dump(model, buffer, **dump_kwargs)
        except PicklingWarning as err:
            raise Exception(
                "Above warning raised by dill when saving. Likely, you need to move your model class into its own file and import it."
            ) from err
    # getbuffer() exposes the bytes without making another full copy.
    file.write_bytes(buffer.getbuffer())
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def load_model(file: Path, **load_kwargs) -> ShareModule:
    """
    Load a model from a file.

    Warning:
        Loading uses dill, which follows pickle semantics: unpickling can execute arbitrary
        code embedded in the file. Only load model files from sources you trust.

    Args:
        file (Path): The file to load the model from. A ``str`` path is also accepted. Any
            existing suffix is replaced by ``MODELFILE_EXT`` (".model").
        load_kwargs: Additional arguments to pass to dill.load.

    Returns:
        ShareModule: The loaded model.

    Raises:
        FileNotFoundError: If the file does not exist.
        TypeError: If the loaded model is not of type ShareModule.
    """
    # Ensure always the same file extension (Path() also accepts str input)
    file = Path(file).with_suffix(MODELFILE_EXT)
    # Check that the file exists
    if not file.exists():
        raise FileNotFoundError(f"File {file} does not exist.")
    # SECURITY: dill.load can run code embedded in the file; never use on untrusted input.
    with open(file, "rb") as f:
        model = load(f, **load_kwargs)
    # Check that the model is of the right type
    if not isinstance(model, ShareModule):
        raise TypeError("Loaded from file successfully, but model is not of type ShareModule.")
    return model
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
"""kernels.py - Kernel classes that implement various covariance functions expressed by their power spectral density or 'feature weights' for use with Fourier-accelerated GP models."""
|
|
2
|
+
|
|
3
|
+
from abc import abstractmethod
|
|
4
|
+
|
|
5
|
+
import jax.numpy as jnp
|
|
6
|
+
from equinox import Module
|
|
7
|
+
from jaxtyping import Array, ArrayLike
|
|
8
|
+
|
|
9
|
+
from spectracles.model.parameter import Parameter
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def normalise_fw(fw: ArrayLike) -> Array:
    """
    Rescale feature weights so that their total power (sum of squared magnitudes) equals 2.

    The factor of sqrt(2) accounts for the halving in total power incurred by enforcing the
    Fourier coefficients to be real via conjugate symmetry. An all-zero input has zero power,
    so the division yields NaNs.
    """
    total_power = jnp.sum(jnp.abs(fw) ** 2)
    root_power = jnp.sqrt(total_power)
    return jnp.sqrt(2) * fw / root_power
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def matern_kernel_fw_nd(
    freqs: ArrayLike,
    length: ArrayLike,
    var: ArrayLike,
    nu: float,
    n: int,
) -> Array:
    """Square root of the PSD of the Matern kernel in n dimensions.

    Evaluates ``(1 + (freqs * length)**2) ** (-(nu + n / 2) / 2)``, rescales it with
    ``normalise_fw`` and multiplies by ``sqrt(var)``.

    NOTE(review): the textbook Matern spectral density is proportional to
    ``(2*nu/length**2 + omega**2) ** -(nu + n/2)``. The ``1 + (f*l)**2`` form used here differs by
    a nu-dependent rescaling of ``length``, so ``length`` is not the standard Matern length
    scale. Confirm this convention is intended.

    Args:
        freqs: Frequencies at which to evaluate the weights (combined elementwise with length).
        length: Kernel length scale.
        var: Kernel variance; the returned weights scale with its square root.
        nu: Matern smoothness parameter (Matern12/32/52 pass 0.5, 1.5 and 2.5).
        n: Number of input dimensions.

    Returns:
        Normalised feature weights evaluated at ``freqs``.
    """
    decay_power = -0.5 * (nu + n / 2)
    scaled_freqs = freqs * length
    unnormalised = (1 + scaled_freqs**2) ** decay_power
    return jnp.sqrt(var) * normalise_fw(unnormalised)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class Kernel(Module):
    """Base class for covariance kernels used by the Fourier-accelerated GP models.

    Subclasses implement ``feature_weights``, which the module docstring describes as the
    kernel's power spectral density or 'feature weights' evaluated at given frequencies.
    """

    # All kernels should have a length scale and variance
    length_scale: Parameter
    variance: Parameter

    @abstractmethod
    def feature_weights(self, freqs: Array) -> Array:
        """Return this kernel's feature weights evaluated at ``freqs``."""
        pass
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class Matern12(Kernel):
    """Matern kernel with smoothness nu = 1/2, evaluated via ``matern_kernel_fw_nd`` in 2D."""

    length_scale: Parameter
    variance: Parameter

    def __init__(self, length_scale: Parameter, variance: Parameter):
        self.length_scale = length_scale
        self.variance = variance

    def feature_weights(self, freqs: Array) -> Array:
        """Feature weights for nu = 0.5 with n = 2 input dimensions."""
        return matern_kernel_fw_nd(
            freqs,
            length=self.length_scale.val,
            var=self.variance.val,
            nu=0.5,
            n=2,
        )
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class Matern32(Kernel):
    """Matern kernel with smoothness nu = 3/2, evaluated via ``matern_kernel_fw_nd`` in 2D."""

    length_scale: Parameter
    variance: Parameter

    def __init__(self, length_scale: Parameter, variance: Parameter):
        self.length_scale = length_scale
        self.variance = variance

    def feature_weights(self, freqs: Array) -> Array:
        """Feature weights for nu = 1.5 with n = 2 input dimensions."""
        return matern_kernel_fw_nd(
            freqs,
            length=self.length_scale.val,
            var=self.variance.val,
            nu=1.5,
            n=2,
        )
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class Matern52(Kernel):
    """Matern kernel with smoothness nu = 5/2, evaluated via ``matern_kernel_fw_nd`` in 2D."""

    length_scale: Parameter
    variance: Parameter

    def __init__(self, length_scale: Parameter, variance: Parameter):
        self.length_scale = length_scale
        self.variance = variance

    def feature_weights(self, freqs: Array) -> Array:
        """Feature weights for nu = 2.5 with n = 2 input dimensions."""
        return matern_kernel_fw_nd(
            freqs,
            length=self.length_scale.val,
            var=self.variance.val,
            nu=2.5,
            n=2,
        )
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class SquaredExponential(Kernel):
    """Squared-exponential kernel expressed through Gaussian-shaped feature weights."""

    length_scale: Parameter
    variance: Parameter

    def __init__(self, length_scale: Parameter, variance: Parameter):
        self.length_scale = length_scale
        self.variance = variance

    def feature_weights(self, freqs: Array) -> Array:
        """Weights ``exp(-freqs**2 * length_scale**2 / 4)``, normalised, times ``sqrt(variance)``."""
        ell = self.length_scale.val
        # NOTE(review): the ``+ 1e-4`` is inside the exponential. It therefore only multiplies
        # every weight by the constant exp(1e-4), which normalise_fw divides back out (up to
        # floating-point rounding). If a small floor on the weights was intended, it would need
        # to be ``jnp.exp(...) + 1e-4``. It is kept as-is here to preserve current results.
        log_weights = -0.25 * freqs**2 * ell**2 + 1e-4
        return jnp.sqrt(self.variance.val) * normalise_fw(jnp.exp(log_weights))
|