astroemu 0.1.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astroemu-0.1.2/LICENSE +21 -0
- astroemu-0.1.2/PKG-INFO +59 -0
- astroemu-0.1.2/README.md +44 -0
- astroemu-0.1.2/astroemu/__init__.py +1 -0
- astroemu-0.1.2/astroemu/_version.py +1 -0
- astroemu-0.1.2/astroemu/dataloaders.py +154 -0
- astroemu-0.1.2/astroemu/network.py +88 -0
- astroemu-0.1.2/astroemu/normalisation.py +164 -0
- astroemu-0.1.2/astroemu/utils.py +83 -0
- astroemu-0.1.2/astroemu.egg-info/PKG-INFO +59 -0
- astroemu-0.1.2/astroemu.egg-info/SOURCES.txt +14 -0
- astroemu-0.1.2/astroemu.egg-info/dependency_links.txt +1 -0
- astroemu-0.1.2/astroemu.egg-info/requires.txt +7 -0
- astroemu-0.1.2/astroemu.egg-info/top_level.txt +1 -0
- astroemu-0.1.2/pyproject.toml +39 -0
- astroemu-0.1.2/setup.cfg +4 -0
astroemu-0.1.2/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 Harry Bevins
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
astroemu-0.1.2/PKG-INFO
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: astroemu
|
|
3
|
+
Version: 0.1.2
|
|
4
|
+
Summary: Generalised framework for emulating spectral signals.
|
|
5
|
+
Requires-Python: >=3.11
|
|
6
|
+
Description-Content-Type: text/markdown
|
|
7
|
+
License-File: LICENSE
|
|
8
|
+
Requires-Dist: jax
|
|
9
|
+
Requires-Dist: jaxlib
|
|
10
|
+
Requires-Dist: torch
|
|
11
|
+
Provides-Extra: dev
|
|
12
|
+
Requires-Dist: ruff; extra == "dev"
|
|
13
|
+
Requires-Dist: pre-commit; extra == "dev"
|
|
14
|
+
Dynamic: license-file
|
|
15
|
+
|
|
16
|
+
# Next Generation Emulators for Cosmology and Astrophysics
|
|
17
|
+
|
|
18
|
+
| | |
|
|
19
|
+
|---| ---|
|
|
20
|
+
| Author| Harry Bevins|
|
|
21
|
+
| Version| 0.1.2 |
|
|
22
|
+
| Homepage | https://github.com/htjb/astroemu |
|
|
23
|
+
|
|
24
|
+
UNDER DEVELOPMENT
|
|
25
|
+
|
|
26
|
+
`astroemu` implements a generalized framework for emulating
|
|
27
|
+
spectral signals
|
|
28
|
+
and is inspired by the [`globalemu`](https://github.com/htjb/globalemu) package.
|
|
29
|
+
|
|
30
|
+
The neural network emulators are implemented in JAX and the dataloaders are
|
|
31
|
+
built on top of PyTorch.
|
|
32
|
+
|
|
33
|
+
As with `globalemu` the idea is to input the independent variables alongside
|
|
34
|
+
the physical parameters of your model then predicting a single corresponding
|
|
35
|
+
spectral value. Full spectra can then be generated via a vectorised call to
|
|
36
|
+
the network. The training data is tiled in the dataloaders so that the
|
|
37
|
+
parameters and independent variables are concatenated as inputs and
|
|
38
|
+
stacked up alongside the outputs. For example if we have a signal
|
|
39
|
+
$y = f(x, \theta)$ and we have N $\theta$ samples and m $x$ and $y$ values then
|
|
40
|
+
our training data looks like
|
|
41
|
+
|
|
42
|
+
|Input|Output|
|
|
43
|
+
|--|--|
|
|
44
|
+
|[$\theta_{0}$, $x_0$]| $y_0$ |
|
|
45
|
+
|[$\theta_{0}$, $x_1$]| $y_1$ |
|
|
46
|
+
|[$\theta_{0}$, $x_2$]| $y_2$ |
|
|
47
|
+
|[$\theta_{0}$, ...]|...|
|
|
48
|
+
|[$\theta_{0}$, $x_m$]|$y_m$|
|
|
49
|
+
|[..., ...]| ...|
|
|
50
|
+
|[$\theta_N$, $x_m$]|$y_m$|
|
|
51
|
+
|
|
52
|
+
For more details see the `globalemu` [paper](https://arxiv.org/abs/2104.04336).
|
|
53
|
+
A paper is in preparation demonstrating applications of this package to a broad
|
|
54
|
+
range of astrophysical signals.
|
|
55
|
+
|
|
56
|
+
## Contributions
|
|
57
|
+
|
|
58
|
+
Contributions are welcome! Please open an issue to discuss and have a
|
|
59
|
+
read of the Contribution guidelines.
|
astroemu-0.1.2/README.md
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
# Next Generation Emulators for Cosmology and Astrophysics
|
|
2
|
+
|
|
3
|
+
| | |
|
|
4
|
+
|---| ---|
|
|
5
|
+
| Author| Harry Bevins|
|
|
6
|
+
| Version| 0.1.2 |
|
|
7
|
+
| Homepage | https://github.com/htjb/astroemu |
|
|
8
|
+
|
|
9
|
+
UNDER DEVELOPMENT
|
|
10
|
+
|
|
11
|
+
`astroemu` implements a generalised framework for emulating
|
|
12
|
+
spectral signals
|
|
13
|
+
and is inspired by the [`globalemu`](https://github.com/htjb/globalemu) package.
|
|
14
|
+
|
|
15
|
+
The neural network emulators are implemented in JAX and the dataloaders are
|
|
16
|
+
built on top of PyTorch.
|
|
17
|
+
|
|
18
|
+
As with `globalemu` the idea is to input the independent variables alongside
|
|
19
|
+
the physical parameters of your model and predict a single corresponding
|
|
20
|
+
spectral value. Full spectra can then be generated via a vectorised call to
|
|
21
|
+
the network. The training data is tiled in the dataloaders so that the
|
|
22
|
+
parameters and independent variables are concatenated as inputs and
|
|
23
|
+
stacked up alongside the outputs. For example if we have a signal
|
|
24
|
+
$y = f(x, \theta)$ and we have $N$ samples of $\theta$ and $m$ values of $x$,
then, writing $y_{i,j} = f(x_j, \theta_i)$, our training data looks like

|Input|Output|
|--|--|
|[$x_0$, $\theta_{0}$]| $y_{0,0}$ |
|[$x_1$, $\theta_{0}$]| $y_{0,1}$ |
|[$x_2$, $\theta_{0}$]| $y_{0,2}$ |
|[..., $\theta_{0}$]|...|
|[$x_{m-1}$, $\theta_{0}$]|$y_{0,m-1}$|
|[..., ...]| ...|
|[$x_{m-1}$, $\theta_{N-1}$]|$y_{N-1,m-1}$|
|
|
36
|
+
|
|
37
|
+
For more details see the `globalemu` [paper](https://arxiv.org/abs/2104.04336).
|
|
38
|
+
A paper is in preparation demonstrating applications of this package to a broad
|
|
39
|
+
range of astrophysical signals.
|
|
40
|
+
|
|
41
|
+
## Contributions
|
|
42
|
+
|
|
43
|
+
Contributions are welcome! Please open an issue to discuss and have a
|
|
44
|
+
read of the Contribution guidelines.
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""init file for emu package."""
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
__version__ = "0.1.2"  # Package version; keep in sync with pyproject.toml.
|
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
"""Data loaders for emu package."""
|
|
2
|
+
|
|
3
|
+
import jax.numpy as jnp
import torch
from torch.utils.data import Dataset

try:
    from astroemu.normalisation import NormalisationPipeline
except ModuleNotFoundError:
    # Fallback for environments where the package is importable as
    # ``emu`` (the module path this file originally imported from).
    from emu.normalisation import NormalisationPipeline
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def load_spectrum(file: str) -> dict:
    """Load spectrum data from .npz file.

    Args:
        file (str): Path to .npz file.

    Returns:
        dict: Mapping from every array name stored in the archive to its
            array. All arrays are read before returning, so the archive
            is closed by the time the caller sees the result.
    """
    # For archives, jnp.load hands back NumPy's lazy NpzFile, which keeps
    # the underlying file open; the context manager closes it once every
    # array has been read (instead of waiting for garbage collection).
    with jnp.load(file) as data:
        return {name: data[name] for name in data.files}
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class SpectrumDataset(Dataset):
    """Dataset for loading spectra from .npz files.

    Each file holds one spectrum. An item is the pair ``(y, inputs)``
    where ``y`` is the array stored under the ``y`` key and ``inputs``
    has one row per element of ``y``: column 0 is the matching entry of
    the (1-D) ``x`` array and the remaining columns repeat the scalar
    parameters stored in the same file. An optional forward pipeline is
    applied to the pair before it is returned.
    """

    def __init__(
        self,
        files: list[str],
        x: str,
        y: str,
        forward_pipeline: NormalisationPipeline | None = None,
        variable_input: list[str] | str | None = None,
    ) -> None:
        """Initialize SpectrumDataset.

        Args:
            files (list[str]): List of file paths to .npz files.
            x (str): Key for independent variable in .npz files.
            y (str): Key for dependent variable in .npz files.
            forward_pipeline (NormalisationPipeline | None, optional):
                Preprocessing pipeline; when set, ``__getitem__`` returns
                ``forward_pipeline.forward(y, inputs)``.
                Defaults to None.
            variable_input (list[str] | str | None, optional): Key, or
                list of keys, of the scalar parameters used as inputs, in
                the given order. A single string is treated as one key.
                If None (or empty), every key except x and y is used, in
                sorted order. Defaults to None.
        """
        # A bare string would otherwise be iterated character by
        # character in __getitem__, looking up one-letter keys.
        if isinstance(variable_input, str) and variable_input:
            variable_input = [variable_input]
        self.files = files
        self.varied_input = variable_input
        self.forward_pipeline = forward_pipeline
        self.x = x
        self.y = y

    def __len__(self) -> int:
        """Return number of files in dataset."""
        return len(self.files)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Get spectrum and input parameters for given index.

        Args:
            idx (int): Index of the file in ``files``.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: ``(y, inputs)``, or the
            result of ``forward_pipeline.forward(y, inputs)`` when a
            pipeline is set. ``inputs`` has shape
            ``(len(y), 1 + n_params)`` with the x values in column 0.
        """
        data = load_spectrum(self.files[idx])
        x = torch.tensor(data[self.x])
        y = torch.tensor(data[self.y])

        if self.varied_input:
            keys = list(self.varied_input)
        else:
            keys = [k for k in sorted(data) if k not in (self.x, self.y)]

        # Each parameter must hold a single value (``.item()``).
        params = torch.tensor(
            [data[k].item() for k in keys],
            dtype=torch.float32,
        )
        # Repeat the parameter vector once per spectral sample, then
        # prepend the independent variable as the first column.
        tiled = torch.tile(params, (y.shape[0], 1))
        inputs = torch.cat([x[:, None], tiled], dim=1)

        if self.forward_pipeline is not None:
            return self.forward_pipeline.forward(y, inputs)
        return y, inputs
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class NormalizeSpectrumDataset(SpectrumDataset):
    """SpectrumDataset followed by a second (normalisation) pipeline.

    Items come from ``SpectrumDataset.__getitem__``, which already applies
    ``super_forward_pipeline``; if ``forward_pipeline`` is given, its
    ``forward`` is then applied to that result.

    Note on naming: the parent stores ``super_forward_pipeline`` on
    ``self.forward_pipeline``, while this class's ``forward_pipeline``
    argument is stored on ``self.normalize_pipeline``.
    """

    def __init__(
        self,
        files: list[str],
        x: str,
        y: str,
        forward_pipeline: NormalisationPipeline | None = None,
        super_forward_pipeline: NormalisationPipeline | None = None,
        variable_input: list[str] | str | None = None,
    ) -> None:
        """Build the dataset and remember the normalisation pipeline.

        Args:
            files (list[str]): Paths of the .npz files, one per item.
            x (str): Archive key of the independent variable.
            y (str): Archive key of the dependent variable.
            forward_pipeline (NormalisationPipeline | None, optional):
                Normalisation applied last. Defaults to None.
            super_forward_pipeline (NormalisationPipeline | None,
                optional): Preprocessing handed to the parent class and
                applied first. Defaults to None.
            variable_input (list[str] | str | None, optional): Parameter
                keys forwarded to the parent class; None means every key
                except x and y. Defaults to None.
        """
        super().__init__(
            files,
            x,
            y,
            forward_pipeline=super_forward_pipeline,
            variable_input=variable_input,
        )
        self.normalize_pipeline = forward_pipeline

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the preprocessed, then normalised, item at ``idx``.

        Args:
            idx (int): Position of the file in ``files``.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: ``(spectrum, inputs)``
            after both pipelines (each skipped when unset).
        """
        spectrum, inputs = super().__getitem__(idx)
        pipeline = self.normalize_pipeline
        if not pipeline:
            return spectrum, inputs
        return pipeline.forward(spectrum, inputs)
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
"""Neural network implementations for emu package."""
|
|
2
|
+
|
|
3
|
+
import jax
|
|
4
|
+
import jax.numpy as jnp
|
|
5
|
+
from jax import random
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def initialise_mlp(
    in_size: int,
    out_size: int,
    hidden_size: int,
    nlayers: int,
    key: jax.Array,
    scale: float = 1e-1,
) -> dict:
    """Initialize MLP parameters.

    Every weight matrix and bias vector is drawn as
    ``scale * N(0, 1)`` using its own PRNG key split from ``key``.

    Args:
        in_size (int): Input size.
        out_size (int): Output size.
        hidden_size (int): Hidden layer size.
        nlayers (int): Number of hidden (hidden_size -> hidden_size)
            layers.
        key (jax.Array): JAX PRNG key, e.g. ``jax.random.PRNGKey(0)``.
        scale (float, optional): Scale for weight initialization.
            Defaults to 1e-1.

    Returns:
        dict: Flat mapping with entries ``weights{k}`` and ``bias{k}``
            for ``k = 0 .. nlayers + 1``. Layer 0 maps
            ``in_size -> hidden_size``, layers ``1 .. nlayers`` map
            ``hidden_size -> hidden_size`` and layer ``nlayers + 1`` maps
            ``hidden_size -> out_size``.
    """
    # Two keys (weights, bias) per layer: input, nlayers hidden, output.
    keys = random.split(key, 2 * nlayers + 4)
    shapes = (
        [(in_size, hidden_size)]
        + [(hidden_size, hidden_size)] * nlayers
        + [(hidden_size, out_size)]
    )

    params = {}
    for layer, (fan_in, fan_out) in enumerate(shapes):
        # Layer k uses keys 2k and 2k+1, so no key is ever reused.
        w_key, b_key = keys[2 * layer], keys[2 * layer + 1]
        params[f"weights{layer}"] = scale * random.normal(
            w_key, (fan_in, fan_out)
        )
        params[f"bias{layer}"] = scale * random.normal(b_key, (fan_out,))
    return params
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def mlp(params: dict, input: jnp.ndarray) -> jnp.ndarray:
    """Evaluate a residual multi-layer perceptron.

    With ``L = len(params) // 2`` layers (keys ``weights{k}`` and
    ``bias{k}``, ``k = 0 .. L-1``, as produced by ``initialise_mlp``):

    * ``h = input @ W0 + b0`` (no activation);
    * for ``k = 1 .. L-2``: ``h = (relu(h) @ Wk + bk) + h``;
    * output ``h @ W{L-1} + b{L-1}`` (linear, no activation).

    With no hidden layers (``L == 2``) the network is purely affine.

    Args:
        params (dict): MLP parameters.
        input (jnp.ndarray): Input array of shape [..., in_size].

    Returns:
        jnp.ndarray: Output array of shape [..., out_size].
    """
    relu = jax.nn.relu
    depth = len(params) // 2
    last = depth - 1

    hidden = jnp.dot(input, params["weights0"]) + params["bias0"]
    for layer in range(1, last):
        update = (
            jnp.dot(relu(hidden), params[f"weights{layer}"])
            + params[f"bias{layer}"]
        )
        # Hidden layers are square, so the skip connection always fits.
        hidden = update + hidden

    return jnp.dot(hidden, params[f"weights{last}"]) + params[f"bias{last}"]
|
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
"""Normalisation pipelines for emu package."""
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class NormalisationPipeline:
|
|
7
|
+
"""Base class for normalisation pipelines."""
|
|
8
|
+
|
|
9
|
+
def forward(
|
|
10
|
+
self, y: torch.Tensor, x: torch.Tensor | None = None
|
|
11
|
+
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
|
12
|
+
"""Apply forward transformation."""
|
|
13
|
+
raise NotImplementedError
|
|
14
|
+
|
|
15
|
+
def backward(
|
|
16
|
+
self, y: torch.Tensor, x: torch.Tensor | None = None
|
|
17
|
+
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
|
18
|
+
"""Apply backward transformation."""
|
|
19
|
+
raise NotImplementedError
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class standardise(NormalisationPipeline):
    """Standardisation normalisation pipeline.

    ``forward`` computes ``(y - y_mean) / y_std`` and, when ``x`` is
    given, ``(x - x_mean) / x_std``; ``backward`` applies the inverse
    affine maps. The stored statistics are used as-is (with ordinary
    tensor broadcasting).
    """

    def __init__(
        self,
        y_mean: torch.Tensor,
        y_std: torch.Tensor,
        x_mean: torch.Tensor,
        x_std: torch.Tensor,
    ) -> None:
        """Store the standardisation statistics.

        Args:
            y_mean (torch.Tensor): Mean subtracted from the spectrum.
            y_std (torch.Tensor): Scale the spectrum is divided by.
            x_mean (torch.Tensor): Mean subtracted from the inputs.
            x_std (torch.Tensor): Scale the inputs are divided by.
        """
        self.y_mean, self.y_std = y_mean, y_std
        self.x_mean, self.x_std = x_mean, x_std

    def forward(
        self, y: torch.Tensor, x: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Standardise the spectrum and (optionally) the inputs.

        Args:
            y (torch.Tensor): Spectrum tensor.
            x (torch.Tensor, optional): Input parameters tensor.
                Defaults to None, in which case None is returned for it.

        Returns:
            tuple: Standardised spectrum and input parameters.
        """
        scaled_y = (y - self.y_mean) / self.y_std
        scaled_x = None if x is None else (x - self.x_mean) / self.x_std
        return scaled_y, scaled_x

    def backward(
        self, y: torch.Tensor, x: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Undo the standardisation of the spectrum and inputs.

        Args:
            y (torch.Tensor): Standardised spectrum tensor.
            x (torch.Tensor, optional): Standardised input parameters
                tensor. Defaults to None, in which case None is returned
                for it.

        Returns:
            tuple: Destandardised spectrum and input parameters.
        """
        restored_y = y * self.y_std + self.y_mean
        restored_x = None if x is None else x * self.x_std + self.x_mean
        return restored_y, restored_x
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class log_base_10(NormalisationPipeline):
    """Logarithm base 10 transformation for numerical stability.

    ``forward`` maps selected values ``v`` to ``log10(v + eps)`` and
    ``backward`` maps them back with ``10**v - eps``. The tensors passed
    in are never modified: when a selector is used, the selected columns
    are rewritten on a copy (matching the no-selector path, which always
    produced a new tensor).
    """

    def __init__(
        self,
        yselector: list[int] | None = None,
        xselector: list[int] | None = None,
        eps: float = 1e-15,
    ) -> None:
        """Logarithm base 10 transformation for numerical stability.

        Args:
            yselector (list[int] | None): columns of the spectrum to
                apply log transformation.
                Assumes that the spectra are in the last dimension.
                If None, every element of the spectrum is transformed.
            xselector (list[int] | None): columns of the input parameters
                to apply log transformation. If None, every element of
                the input tensor is transformed.
            eps (float): small value to add to avoid log(0).
        """
        self.yselector = yselector
        self.xselector = xselector
        self.eps = eps

    def _transform(
        self,
        tensor: torch.Tensor,
        selector: list[int] | None,
        inverse: bool,
    ) -> torch.Tensor:
        """Apply log10 (or its inverse) to the selected last-dim columns."""

        def op(values: torch.Tensor) -> torch.Tensor:
            if inverse:
                return torch.pow(10, values) - self.eps
            return torch.log10(values + self.eps)

        if selector is None:
            return op(tensor)
        # Clone so the column assignments below do not overwrite the
        # caller's tensor (in-place writes would also be rejected on a
        # leaf tensor that requires grad).
        out = tensor.clone()
        for i in selector:
            out[..., i] = op(out[..., i])
        return out

    def forward(
        self, y: torch.Tensor, x: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Apply log10 transformation to selected columns.

        Args:
            y (torch.Tensor): Spectrum tensor.
            x (torch.Tensor, optional): Input parameters tensor.
                Defaults to None.

        Returns:
            tuple: Transformed spectrum and input parameters (x stays
            None when not given). The arguments are not modified.
        """
        if x is not None:
            x = self._transform(x, self.xselector, inverse=False)
        y = self._transform(y, self.yselector, inverse=False)
        return y, x

    def backward(
        self, y: torch.Tensor, x: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Apply inverse log10 transformation to selected columns.

        Args:
            y (torch.Tensor): Transformed spectrum tensor.
            x (torch.Tensor, optional): Transformed input parameters
                tensor. Defaults to None.

        Returns:
            tuple: Inverse transformed spectrum and input parameters (x
            stays None when not given). The arguments are not modified.
        """
        if x is not None:
            x = self._transform(x, self.xselector, inverse=True)
        y = self._transform(y, self.yselector, inverse=True)
        return y, x
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
"""Utility functions for emu package."""
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def compute_mean_std(
    loader: torch.utils.data.DataLoader,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Memory safe mean and std computation.

    Streams over ``loader`` once, accumulating per-feature sums and sums
    of squares, so the dataset never has to be held in memory at once.
    Variances are population variances, ``E[v**2] - E[v]**2``.

    Args:
        loader: DataLoader (or any iterable) yielding ``(spec, input)``
            batches where
            - spec: [batch_size, M]
            - input: [batch_size, M, N] (all leading dims are flattened;
              only the last dim N is kept)

    Returns:
        mean_spec: [M] - mean across all spectra
        std_spec: [M] - std across all spectra; entries whose variance
            is <= 1e-3 are set to 1
        mean_input: [N] - mean across batches and the M dimension
        std_input: [N] - std across batches and the M dimension, with
            the variance clamped to at least 1e-8

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    # Accumulators (allocated lazily so dtype/device follow the data).
    spec_sum = spec_sum_sq = None
    input_sum = input_sum_sq = None
    n_spec_samples = 0
    n_input_samples = 0

    for spec, input_data in loader:
        n_features = input_data.size(-1)
        if spec_sum is None:
            spec_sum = torch.zeros(
                spec.size(1), dtype=spec.dtype, device=spec.device
            )
            spec_sum_sq = torch.zeros_like(spec_sum)
            input_sum = torch.zeros(
                n_features,
                dtype=input_data.dtype,
                device=input_data.device,
            )
            input_sum_sq = torch.zeros_like(input_sum)

        # spec: stats across the batch dimension.
        spec_sum += spec.sum(dim=0)
        spec_sum_sq += (spec**2).sum(dim=0)
        n_spec_samples += spec.size(0)

        # input: stats across every dimension except the last. reshape
        # (unlike view) also accepts non-contiguous batches.
        input_flat = input_data.reshape(-1, n_features)
        input_sum += input_flat.sum(dim=0)
        input_sum_sq += (input_flat**2).sum(dim=0)
        n_input_samples += input_flat.size(0)

    if spec_sum is None:
        raise ValueError("compute_mean_std: loader yielded no batches.")

    mean_spec = spec_sum / n_spec_samples
    var_spec = (spec_sum_sq / n_spec_samples) - (mean_spec**2)
    # Near-constant spectral channels get unit std to avoid blow-ups.
    std_spec = torch.where(var_spec <= 1e-3, 1, torch.sqrt(var_spec))

    mean_input = input_sum / n_input_samples
    var_input = (input_sum_sq / n_input_samples) - (mean_input**2)
    # Clamp guards against tiny/negative variances from rounding.
    std_input = torch.sqrt(torch.clamp(var_input, min=1e-8))

    return mean_spec, std_spec, mean_input, std_input
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: astroemu
|
|
3
|
+
Version: 0.1.2
|
|
4
|
+
Summary: Generalised framework for emulating spectral signals.
|
|
5
|
+
Requires-Python: >=3.11
|
|
6
|
+
Description-Content-Type: text/markdown
|
|
7
|
+
License-File: LICENSE
|
|
8
|
+
Requires-Dist: jax
|
|
9
|
+
Requires-Dist: jaxlib
|
|
10
|
+
Requires-Dist: torch
|
|
11
|
+
Provides-Extra: dev
|
|
12
|
+
Requires-Dist: ruff; extra == "dev"
|
|
13
|
+
Requires-Dist: pre-commit; extra == "dev"
|
|
14
|
+
Dynamic: license-file
|
|
15
|
+
|
|
16
|
+
# Next Generation Emulators for Cosmology and Astrophysics
|
|
17
|
+
|
|
18
|
+
| | |
|
|
19
|
+
|---| ---|
|
|
20
|
+
| Author| Harry Bevins|
|
|
21
|
+
| Version| 0.1.2 |
|
|
22
|
+
| Homepage | https://github.com/htjb/astroemu |
|
|
23
|
+
|
|
24
|
+
UNDER DEVELOPMENT
|
|
25
|
+
|
|
26
|
+
`astroemu` implements a generalized framework for emulating
|
|
27
|
+
spectral signals
|
|
28
|
+
and is inspired by the [`globalemu`](https://github.com/htjb/globalemu) package.
|
|
29
|
+
|
|
30
|
+
The neural network emulators are implemented in JAX and the dataloaders are
|
|
31
|
+
built on top of PyTorch.
|
|
32
|
+
|
|
33
|
+
As with `globalemu` the idea is to input the independent variables alongside
|
|
34
|
+
the physical parameters of your model then predicting a single corresponding
|
|
35
|
+
spectral value. Full spectra can then be generated via a vectorised call to
|
|
36
|
+
the network. The training data is tiled in the dataloaders so that the
|
|
37
|
+
parameters and independent variables are concatenated as inputs and
|
|
38
|
+
stacked up alongside the outputs. For example if we have a signal
|
|
39
|
+
$y = f(x, \theta)$ and we have N $\theta$ samples and m $x$ and $y$ values then
|
|
40
|
+
our training data looks like
|
|
41
|
+
|
|
42
|
+
|Input|Output|
|
|
43
|
+
|--|--|
|
|
44
|
+
|[$\theta_{0}$, $x_0$]| $y_0$ |
|
|
45
|
+
|[$\theta_{0}$, $x_1$]| $y_1$ |
|
|
46
|
+
|[$\theta_{0}$, $x_2$]| $y_2$ |
|
|
47
|
+
|[$\theta_{0}$, ...]|...|
|
|
48
|
+
|[$\theta_{0}$, $x_m$]|$y_m$|
|
|
49
|
+
|[..., ...]| ...|
|
|
50
|
+
|[$\theta_N$, $x_m$]|$y_m$|
|
|
51
|
+
|
|
52
|
+
For more details see the `globalemu` [paper](https://arxiv.org/abs/2104.04336).
|
|
53
|
+
A paper is in preparation demonstrating applications of this package to a broad
|
|
54
|
+
range of astrophysical signals.
|
|
55
|
+
|
|
56
|
+
## Contributions
|
|
57
|
+
|
|
58
|
+
Contributions are welcome! Please open an issue to discuss and have a
|
|
59
|
+
read of the Contribution guidelines.
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
LICENSE
|
|
2
|
+
README.md
|
|
3
|
+
pyproject.toml
|
|
4
|
+
astroemu/__init__.py
|
|
5
|
+
astroemu/_version.py
|
|
6
|
+
astroemu/dataloaders.py
|
|
7
|
+
astroemu/network.py
|
|
8
|
+
astroemu/normalisation.py
|
|
9
|
+
astroemu/utils.py
|
|
10
|
+
astroemu.egg-info/PKG-INFO
|
|
11
|
+
astroemu.egg-info/SOURCES.txt
|
|
12
|
+
astroemu.egg-info/dependency_links.txt
|
|
13
|
+
astroemu.egg-info/requires.txt
|
|
14
|
+
astroemu.egg-info/top_level.txt
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
astroemu
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "astroemu"
|
|
3
|
+
version = "0.1.2"
|
|
4
|
+
description = "Generalised framework for emulating spectral signals."
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
requires-python = ">=3.11"
|
|
7
|
+
dependencies = [
|
|
8
|
+
"jax",
|
|
9
|
+
"jaxlib",
|
|
10
|
+
"torch",
|
|
11
|
+
]
|
|
12
|
+
|
|
13
|
+
[build-system]
|
|
14
|
+
requires = ["setuptools>=77.0", "wheel"]
|
|
15
|
+
build-backend = "setuptools.build_meta"
|
|
16
|
+
|
|
17
|
+
[tool.setuptools.packages.find]
|
|
18
|
+
where = ["."]
|
|
19
|
+
include = ["astroemu*"]
|
|
20
|
+
|
|
21
|
+
[project.optional-dependencies]
|
|
22
|
+
dev = [
|
|
23
|
+
"ruff",
|
|
24
|
+
"pre-commit",
|
|
25
|
+
]
|
|
26
|
+
|
|
27
|
+
[tool.ruff]
|
|
28
|
+
line-length = 79
|
|
29
|
+
|
|
30
|
+
[tool.ruff.lint]
|
|
31
|
+
select =["E", "F", "W", # basics
|
|
32
|
+
"I", # isort
|
|
33
|
+
"D", # docstrings
|
|
34
|
+
"UP", # pyupgrade (includes type modernization)
|
|
35
|
+
"ANN", # type annotations
|
|
36
|
+
]
|
|
37
|
+
|
|
38
|
+
[tool.ruff.lint.pydocstyle]
|
|
39
|
+
convention = "google"
|
astroemu-0.1.2/setup.cfg
ADDED