astroemu 0.1.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
astroemu-0.1.2/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Harry Bevins
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,59 @@
1
+ Metadata-Version: 2.4
2
+ Name: astroemu
3
+ Version: 0.1.2
4
+ Summary: Generalised framework for emulating spectral signals.
5
+ Requires-Python: >=3.11
6
+ Description-Content-Type: text/markdown
7
+ License-File: LICENSE
8
+ Requires-Dist: jax
9
+ Requires-Dist: jaxlib
10
+ Requires-Dist: torch
11
+ Provides-Extra: dev
12
+ Requires-Dist: ruff; extra == "dev"
13
+ Requires-Dist: pre-commit; extra == "dev"
14
+ Dynamic: license-file
15
+
16
+ # Next Generation Emulators for Cosmology and Astrophysics
17
+
18
+ | | |
19
+ |---| ---|
20
+ | Author| Harry Bevins|
21
+ | Version| 0.1.2 |
22
+ | Homepage | https://github.com/htjb/astroemu |
23
+
24
+ UNDER DEVELOPMENT
25
+
26
+ `astroemu` implements a generalized framework for emulating
27
+ spectral signals
28
+ and is inspired by the [`globalemu`](https://github.com/htjb/globalemu) package.
29
+
30
+ The neural network emulators are implemented in JAX and the dataloaders are
31
+ built on top of PyTorch.
32
+
33
+ As with `globalemu` the idea is to input the independent variables alongside
34
+ the physical parameters of your model then predicting a single corresponding
35
+ spectral value. Full spectra can then be generated via a vectorised call to
36
+ the network. The training data is tiled in the dataloaders so that the
37
+ parameters and independent variables are concatenated as inputs and
38
+ stacked up alongside the outputs. For example if we have a signal
39
+ $y = f(x, \theta)$ and we have N $\theta$ samples and m $x$ and $y$ values then
40
+ our training data looks like
41
+
42
+ |Input|Output|
43
+ |--|--|
44
+ |[$\theta_{0}$, $x_0$]| $y_{0,0}$ |
45
+ |[$\theta_{0}$, $x_1$]| $y_{0,1}$ |
46
+ |[$\theta_{0}$, $x_2$]| $y_{0,2}$ |
47
+ |[$\theta_{0}$, ...]|...|
48
+ |[$\theta_{0}$, $x_{m-1}$]|$y_{0,m-1}$|
49
+ |[..., ...]| ...|
50
+ |[$\theta_{N-1}$, $x_{m-1}$]|$y_{N-1,m-1}$|
51
+
52
+ For more details see the `globalemu` [paper](https://arxiv.org/abs/2104.04336).
53
+ A paper is in preparation demonstrating applications of this package to a broad
54
+ range of astrophysical signals.
55
+
56
+ ## Contributions
57
+
58
+ Contributions are welcome! Please open an issue to discuss and have a
59
+ read of the Contribution guidelines.
@@ -0,0 +1,44 @@
1
+ # Next Generation Emulators for Cosmology and Astrophysics
2
+
3
+ | | |
4
+ |---| ---|
5
+ | Author| Harry Bevins|
6
+ | Version| 0.1.2 |
7
+ | Homepage | https://github.com/htjb/astroemu |
8
+
9
+ UNDER DEVELOPMENT
10
+
11
+ `astroemu` implements a generalized framework for emulating
12
+ spectral signals
13
+ and is inspired by the [`globalemu`](https://github.com/htjb/globalemu) package.
14
+
15
+ The neural network emulators are implemented in JAX and the dataloaders are
16
+ built on top of PyTorch.
17
+
18
+ As with `globalemu` the idea is to input the independent variables alongside
19
+ the physical parameters of your model and then predict a single corresponding
20
+ spectral value. Full spectra can then be generated via a vectorised call to
21
+ the network. The training data is tiled in the dataloaders so that the
22
+ parameters and independent variables are concatenated as inputs and
23
+ stacked up alongside the outputs. For example if we have a signal
24
+ $y = f(x, \theta)$ and we have N $\theta$ samples and m $x$ and $y$ values then
25
+ our training data looks like
26
+
27
+ |Input|Output|
28
+ |--|--|
29
+ |[$\theta_{0}$, $x_0$]| $y_{0,0}$ |
30
+ |[$\theta_{0}$, $x_1$]| $y_{0,1}$ |
31
+ |[$\theta_{0}$, $x_2$]| $y_{0,2}$ |
32
+ |[$\theta_{0}$, ...]|...|
33
+ |[$\theta_{0}$, $x_{m-1}$]|$y_{0,m-1}$|
34
+ |[..., ...]| ...|
35
+ |[$\theta_{N-1}$, $x_{m-1}$]|$y_{N-1,m-1}$|
36
+
37
+ For more details see the `globalemu` [paper](https://arxiv.org/abs/2104.04336).
38
+ A paper is in preparation demonstrating applications of this package to a broad
39
+ range of astrophysical signals.
40
+
41
+ ## Contributions
42
+
43
+ Contributions are welcome! Please open an issue to discuss and have a
44
+ read of the Contribution guidelines.
@@ -0,0 +1 @@
1
+ """init file for emu package."""
@@ -0,0 +1 @@
1
# Package version string; matches the `version` field in pyproject.toml.
__version__ = "0.1.2"
@@ -0,0 +1,154 @@
1
+ """Data loaders for emu package."""
2
+
3
+ import jax.numpy as jnp
4
+ import torch
5
+ from torch.utils.data import Dataset
6
+
7
+ from emu.normalisation import NormalisationPipeline
8
+
9
+
10
def load_spectrum(file: str) -> dict:
    """Load every array stored in a ``.npz`` archive into a dictionary.

    The archive is opened as a context manager so its underlying file
    handle is released before returning (``jnp.load`` hands back numpy's
    lazy ``NpzFile`` for ``.npz`` archives, which otherwise stays open).
    All arrays are read eagerly inside the ``with`` block, so the returned
    values remain valid after the archive is closed.

    Args:
        file (str): Path to .npz file.

    Returns:
        dict: Mapping from each array name in the archive to its data.
    """
    with jnp.load(file) as archive:
        return {name: archive[name] for name in archive.files}
22
+
23
+
24
class SpectrumDataset(Dataset):
    """Dataset that turns each ``.npz`` spectrum file into tiled input rows.

    Item ``idx`` is built from ``files[idx]``. The dependent variable ``y``
    is returned as loaded. Each row ``j`` of the input matrix is
    ``[x_j, p_1, ..., p_k]``: the j-th independent-variable value (column
    0) followed by the file's scalar parameters. An optional forward
    pipeline is applied to the ``(y, input)`` pair before it is returned.
    """

    def __init__(
        self,
        files: list[str],
        x: str,
        y: str,
        forward_pipeline: NormalisationPipeline | None = None,
        variable_input: list[str] | str | None = None,
    ) -> None:
        """Initialize SpectrumDataset.

        Args:
            files (list[str]): List of file paths to .npz files.
            x (str): Key for independent variable in .npz files.
            y (str): Key for dependent variable in .npz files.
            forward_pipeline (NormalisationPipeline | None, optional):
                Preprocessing pipeline applied to ``(y, input)``.
                Defaults to None.
            variable_input (list[str] | str | None, optional): Key, or
                keys, of the scalar parameters to use, in this order.
                If None (or empty), every key except ``x`` and ``y`` is
                used, in sorted order. Defaults to None.
        """
        self.files = files
        self.varied_input = variable_input
        self.forward_pipeline = forward_pipeline
        self.x = x
        self.y = y

    def __len__(self) -> int:
        """Return number of files in dataset."""
        return len(self.files)

    def _parameter_keys(self, data: dict) -> list[str]:
        """Return the parameter keys to read from ``data``, in order."""
        if self.varied_input:
            if isinstance(self.varied_input, str):
                # A single key: wrap it so it is not iterated per character.
                return [self.varied_input]
            return list(self.varied_input)
        return [key for key in sorted(data) if key not in (self.x, self.y)]

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Get spectrum and input parameters for given index.

        Args:
            idx (int): Index of the data point.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Tuple of (spectrum,
                input matrix), after the forward pipeline if one is set.
                Each input row is ``[x_j, parameters...]``.
        """
        data = load_spectrum(self.files[idx])
        x = torch.tensor(data[self.x])
        y = torch.tensor(data[self.y])
        # Each selected entry must be a single value (``.item()``).
        params = torch.tensor(
            [data[key].item() for key in self._parameter_keys(data)],
            dtype=torch.float32,
        )
        # Repeat the parameter vector once per spectral sample.
        params = torch.tile(params, (y.shape[0], 1))
        # Prepend the independent variable as column 0.
        inputs = torch.cat([x[:, None], params], dim=1)
        if self.forward_pipeline:
            return self.forward_pipeline.forward(y, inputs)
        return y, inputs
99
+
100
+
101
class NormalizeSpectrumDataset(SpectrumDataset):
    """Spectrum dataset with a second, normalising pipeline stage.

    Samples come from the parent dataset, which applies the
    ``super_forward_pipeline`` preprocessing step if one is given. They are
    then passed through ``forward_pipeline`` for normalisation, again only
    if one is given.
    """

    def __init__(
        self,
        files: list[str],
        x: str,
        y: str,
        forward_pipeline: NormalisationPipeline | None = None,
        super_forward_pipeline: NormalisationPipeline | None = None,
        variable_input: list[str] | str | None = None,
    ) -> None:
        """Set up the two pipeline stages.

        Args:
            files (list[str]): Paths of the .npz files.
            x (str): Name of the independent-variable array in each file.
            y (str): Name of the dependent-variable array in each file.
            forward_pipeline (NormalisationPipeline | None, optional):
                Normalisation applied last. Defaults to None.
            super_forward_pipeline (NormalisationPipeline | None,
                optional): Preprocessing handed to the parent dataset and
                applied first. Defaults to None.
            variable_input (list[str] | str | None, optional): Parameter
                keys to use. If None, all keys other than ``x`` and ``y``
                are used. Defaults to None.
        """
        super().__init__(
            files=files,
            x=x,
            y=y,
            forward_pipeline=super_forward_pipeline,
            variable_input=variable_input,
        )
        self.normalize_pipeline = forward_pipeline

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the preprocessed-then-normalised sample at ``idx``.

        Args:
            idx (int): Position of the file in ``files``.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: ``(spectrum, inputs)`` after
                both pipeline stages.
        """
        spectrum, inputs = super().__getitem__(idx)
        pipeline = self.normalize_pipeline
        if not pipeline:
            return spectrum, inputs
        return pipeline.forward(spectrum, inputs)
@@ -0,0 +1,88 @@
1
+ """Neural network implementations for emu package."""
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+ from jax import random
6
+
7
+
8
def initialise_mlp(
    in_size: int,
    out_size: int,
    hidden_size: int,
    nlayers: int,
    key: jax.Array,
    scale: float = 1e-1,
) -> dict:
    """Initialize MLP parameters.

    Layer 0 maps ``in_size -> hidden_size``, layers ``1..nlayers`` map
    ``hidden_size -> hidden_size``, and layer ``nlayers + 1`` maps
    ``hidden_size -> out_size``. Every weight and bias is drawn from a
    standard normal scaled by ``scale``, and each one uses its own
    sub-key so no PRNG key is reused.

    Args:
        in_size (int): Input size.
        out_size (int): Output size.
        hidden_size (int): Hidden layer size.
        nlayers (int): Number of hidden (hidden -> hidden) layers.
        key (jax.Array): JAX PRNG key, e.g. ``jax.random.PRNGKey(0)``.
        scale (float, optional): Scale for weight initialization.
            Defaults to 1e-1.

    Returns:
        dict: Parameters keyed ``"weights{i}"`` / ``"bias{i}"`` for
            ``i`` in ``0..nlayers + 1``.
    """
    # Two keys per layer: input layer, nlayers hidden layers, output layer.
    keys = random.split(key, nlayers * 2 + 2 + 2)
    shapes = (
        [(in_size, hidden_size)]
        + [(hidden_size, hidden_size)] * nlayers
        + [(hidden_size, out_size)]
    )
    params = {}
    for layer, (fan_in, fan_out) in enumerate(shapes):
        # Layer i owns keys 2i (weights) and 2i + 1 (bias) exclusively.
        w_key, b_key = keys[2 * layer], keys[2 * layer + 1]
        params[f"weights{layer}"] = scale * random.normal(
            w_key, (fan_in, fan_out)
        )
        params[f"bias{layer}"] = scale * random.normal(b_key, (fan_out,))
    return params
59
+
60
+
61
def mlp(
    params: dict, input: jnp.ndarray, *, activation: str = "relu"
) -> jnp.ndarray:
    """Multi-layer perceptron with residual connections.

    With ``L = len(params) // 2`` layers, computation proceeds in three
    steps:

    - Layer 0 is a linear projection.
    - Layers ``1..L-2`` are pre-activation residual blocks,
      ``h <- W_i @ act(h) + b_i + h``.
    - Layer ``L-1`` is a linear read-out with no activation.

    Args:
        params (dict): MLP parameters keyed ``"weights{i}"`` /
            ``"bias{i}"``.
        input (jnp.ndarray): Input array of shape [..., in_size].
        activation (str, optional): Name of the ``jax.nn`` function used
            inside the residual blocks. Keyword-only. Under ``jax.jit``,
            mark it static if you pass it. Defaults to ``"relu"``.

    Returns:
        jnp.ndarray: Output array of shape [..., out_size].

    Raises:
        AttributeError: If ``activation`` is not an attribute of
            ``jax.nn``.
    """
    act_fn = getattr(jax.nn, activation)
    num_layers = len(params) // 2  # total layers: input + hidden(s) + output

    x = jnp.dot(input, params["weights0"]) + params["bias0"]

    for i in range(1, num_layers - 1):  # exclude final output layer
        # Hidden layers are square, so the skip connection always matches.
        x = jnp.dot(act_fn(x), params[f"weights{i}"]) + params[f"bias{i}"] + x

    # Final layer: linear only, no activation.
    last = num_layers - 1
    return jnp.dot(x, params[f"weights{last}"]) + params[f"bias{last}"]
@@ -0,0 +1,164 @@
1
+ """Normalisation pipelines for emu package."""
2
+
3
+ import torch
4
+
5
+
6
+ class NormalisationPipeline:
7
+ """Base class for normalisation pipelines."""
8
+
9
+ def forward(
10
+ self, y: torch.Tensor, x: torch.Tensor | None = None
11
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
12
+ """Apply forward transformation."""
13
+ raise NotImplementedError
14
+
15
+ def backward(
16
+ self, y: torch.Tensor, x: torch.Tensor | None = None
17
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
18
+ """Apply backward transformation."""
19
+ raise NotImplementedError
20
+
21
+
22
class standardise(NormalisationPipeline):
    """Z-score transform ``(v - mean) / std`` for spectra and inputs."""

    def __init__(
        self,
        y_mean: torch.Tensor,
        y_std: torch.Tensor,
        x_mean: torch.Tensor,
        x_std: torch.Tensor,
    ) -> None:
        """Store the statistics used by the transform.

        The statistics are used with ordinary broadcasting arithmetic.
        Presumably they are tensors shaped like the data's trailing
        dimensions; confirm against the caller.

        Args:
            y_mean (torch.Tensor): Offset subtracted from the spectrum.
            y_std (torch.Tensor): Scale the centred spectrum is divided by.
            x_mean (torch.Tensor): Offset subtracted from the inputs.
            x_std (torch.Tensor): Scale the centred inputs are divided by.
        """
        self.y_mean, self.y_std = y_mean, y_std
        self.x_mean, self.x_std = x_mean, x_std

    def forward(
        self, y: torch.Tensor, x: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Centre and scale the spectrum and, if given, the inputs.

        Args:
            y (torch.Tensor): Spectrum in data space.
            x (torch.Tensor, optional): Inputs in data space. Defaults to
                None, in which case None is passed through.

        Returns:
            tuple: ``(standardised y, standardised x or None)``.
        """
        scaled_y = (y - self.y_mean) / self.y_std
        scaled_x = None if x is None else (x - self.x_mean) / self.x_std
        return scaled_y, scaled_x

    def backward(
        self, y: torch.Tensor, x: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Undo :meth:`forward`: rescale and re-centre the data.

        Args:
            y (torch.Tensor): Spectrum in standardised space.
            x (torch.Tensor, optional): Inputs in standardised space.
                Defaults to None, in which case None is passed through.

        Returns:
            tuple: ``(y in data space, x in data space or None)``.
        """
        restored_y = y * self.y_std + self.y_mean
        restored_x = None if x is None else x * self.x_std + self.x_mean
        return restored_y, restored_x
85
+
86
+
87
class log_base_10(NormalisationPipeline):
    """Base-10 logarithm transform with an ``eps`` offset, and its inverse.

    ``forward`` maps ``v -> log10(v + eps)``. ``backward`` maps
    ``v -> 10**v - eps``. Each applies either to whole tensors or only to
    selected columns of the last dimension. The caller's tensors are never
    modified in place.
    """

    def __init__(
        self,
        yselector: list[int] | None = None,
        xselector: list[int] | None = None,
        eps: float = 1e-15,
    ) -> None:
        """Configure which columns are transformed.

        Args:
            yselector (list[int] | None): Indices along the last dimension
                of the spectrum to transform. ``None`` transforms every
                element.
            xselector (list[int] | None): Indices along the last dimension
                of the input parameters to transform. ``None`` transforms
                every element.
            eps (float): Offset added before ``log10`` (and subtracted
                after ``10**``) to avoid ``log10(0)``.
        """
        self.yselector = yselector
        self.xselector = xselector
        self.eps = eps

    def _elementwise(
        self, values: torch.Tensor, inverse: bool
    ) -> torch.Tensor:
        """Apply ``log10(v + eps)``, or ``10**v - eps`` if ``inverse``."""
        if inverse:
            return torch.pow(10, values) - self.eps
        return torch.log10(values + self.eps)

    def _apply(
        self,
        tensor: torch.Tensor,
        selector: list[int] | None,
        inverse: bool,
    ) -> torch.Tensor:
        """Transform ``tensor`` (or its selected columns) out of place."""
        if selector is None:
            return self._elementwise(tensor, inverse)
        # Work on a copy so the caller's tensor is untouched (and leaf
        # tensors requiring grad are not written in place). Promote
        # integer tensors so the float results are not truncated.
        if tensor.is_floating_point():
            out = tensor.clone()
        else:
            out = tensor.to(torch.get_default_dtype())
        for i in selector:
            out[..., i] = self._elementwise(out[..., i], inverse)
        return out

    def forward(
        self, y: torch.Tensor, x: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Apply log10 transformation to selected columns.

        Args:
            y (torch.Tensor): Spectrum tensor.
            x (torch.Tensor, optional): Input parameters tensor.
                Defaults to None.

        Returns:
            tuple: Transformed spectrum and input parameters (new tensors
                whenever a selector is used).
        """
        if x is not None:
            x = self._apply(x, self.xselector, inverse=False)
        y = self._apply(y, self.yselector, inverse=False)
        return y, x

    def backward(
        self, y: torch.Tensor, x: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Apply inverse log10 transformation to selected columns.

        Args:
            y (torch.Tensor): Transformed spectrum tensor.
            x (torch.Tensor, optional): Transformed input parameters tensor.
                Defaults to None.

        Returns:
            tuple: Inverse transformed spectrum and input parameters (new
                tensors whenever a selector is used).
        """
        if x is not None:
            x = self._apply(x, self.xselector, inverse=True)
        y = self._apply(y, self.yselector, inverse=True)
        return y, x
@@ -0,0 +1,83 @@
1
+ """Utility functions for emu package."""
2
+
3
+ import torch
4
+
5
+
6
+ def compute_mean_std(
7
+ loader: torch.utils.data.DataLoader,
8
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
9
+ """Memory safe mean and std computation.
10
+
11
+ Args:
12
+ loader: DataLoader returning (spec, input) where:
13
+ - spec: [batch_size, 5000]
14
+ - input: [batch_size, 5000, N]
15
+
16
+ Returns:
17
+ mean_spec: [5000] - mean across batches
18
+ std_spec: [5000] - std across batches
19
+ mean_input: [N] - mean across batches and the 5000 dimension
20
+ std_input: [N] - std across batches and the 5000 dimension
21
+ """
22
+ # Accumulators
23
+ spec_sum = None
24
+ spec_sum_sq = None
25
+ input_sum = None
26
+ input_sum_sq = None
27
+ n_spec_samples = 0
28
+ n_input_samples = 0
29
+
30
+ for spec, input_data in loader:
31
+ batch_size = spec.size(0)
32
+
33
+ # === Process spec ===
34
+ # spec shape: [batch_size, 5000] -> we want stats across batch dim
35
+ if spec_sum is None:
36
+ spec_sum = torch.zeros(
37
+ spec.size(1), dtype=spec.dtype, device=spec.device
38
+ )
39
+ spec_sum_sq = torch.zeros(
40
+ spec.size(1), dtype=spec.dtype, device=spec.device
41
+ )
42
+
43
+ spec_sum += spec.sum(dim=0) # sum across batch
44
+ spec_sum_sq += (spec**2).sum(dim=0) # sum of squares across batch
45
+ n_spec_samples += batch_size
46
+
47
+ # === Process input ===
48
+ # input shape: [batch_size, 5000, N] -> we want stats across
49
+ # batch and 5000 dims
50
+ input_flat = input_data.view(
51
+ -1, input_data.size(-1)
52
+ ) # [batch_size * 5000, N]
53
+
54
+ if input_sum is None:
55
+ input_sum = torch.zeros(
56
+ input_data.size(-1),
57
+ dtype=input_data.dtype,
58
+ device=input_data.device,
59
+ )
60
+ input_sum_sq = torch.zeros(
61
+ input_data.size(-1),
62
+ dtype=input_data.dtype,
63
+ device=input_data.device,
64
+ )
65
+
66
+ input_sum += input_flat.sum(
67
+ dim=0
68
+ ) # sum across flattened batch*5000 dim
69
+ input_sum_sq += (input_flat**2).sum(dim=0) # sum of squares
70
+ n_input_samples += input_flat.size(0) # batch_size * 5000
71
+
72
+ # Compute means and stds
73
+ mean_spec = spec_sum / n_spec_samples
74
+ var_spec = (spec_sum_sq / n_spec_samples) - (mean_spec**2)
75
+ std_spec = torch.where(var_spec <= 1e-3, 1, torch.sqrt(var_spec))
76
+
77
+ mean_input = input_sum / n_input_samples
78
+ var_input = (input_sum_sq / n_input_samples) - (mean_input**2)
79
+ std_input = torch.sqrt(
80
+ torch.clamp(var_input, min=1e-8)
81
+ ) # clamp for numerical stability
82
+
83
+ return mean_spec, std_spec, mean_input, std_input
@@ -0,0 +1,59 @@
1
+ Metadata-Version: 2.4
2
+ Name: astroemu
3
+ Version: 0.1.2
4
+ Summary: Generalised framework for emulating spectral signals.
5
+ Requires-Python: >=3.11
6
+ Description-Content-Type: text/markdown
7
+ License-File: LICENSE
8
+ Requires-Dist: jax
9
+ Requires-Dist: jaxlib
10
+ Requires-Dist: torch
11
+ Provides-Extra: dev
12
+ Requires-Dist: ruff; extra == "dev"
13
+ Requires-Dist: pre-commit; extra == "dev"
14
+ Dynamic: license-file
15
+
16
+ # Next Generation Emulators for Cosmology and Astrophysics
17
+
18
+ | | |
19
+ |---| ---|
20
+ | Author| Harry Bevins|
21
+ | Version| 0.1.2 |
22
+ | Homepage | https://github.com/htjb/astroemu |
23
+
24
+ UNDER DEVELOPMENT
25
+
26
+ `astroemu` implements a generalized framework for emulating
27
+ spectral signals
28
+ and is inspired by the [`globalemu`](https://github.com/htjb/globalemu) package.
29
+
30
+ The neural network emulators are implemented in JAX and the dataloaders are
31
+ built on top of PyTorch.
32
+
33
+ As with `globalemu` the idea is to input the independent variables alongside
34
+ the physical parameters of your model then predicting a single corresponding
35
+ spectral value. Full spectra can then be generated via a vectorised call to
36
+ the network. The training data is tiled in the dataloaders so that the
37
+ parameters and independent variables are concatenated as inputs and
38
+ stacked up alongside the outputs. For example if we have a signal
39
+ $y = f(x, \theta)$ and we have N $\theta$ samples and m $x$ and $y$ values then
40
+ our training data looks like
41
+
42
+ |Input|Output|
43
+ |--|--|
44
+ |[$\theta_{0}$, $x_0$]| $y_{0,0}$ |
45
+ |[$\theta_{0}$, $x_1$]| $y_{0,1}$ |
46
+ |[$\theta_{0}$, $x_2$]| $y_{0,2}$ |
47
+ |[$\theta_{0}$, ...]|...|
48
+ |[$\theta_{0}$, $x_{m-1}$]|$y_{0,m-1}$|
49
+ |[..., ...]| ...|
50
+ |[$\theta_{N-1}$, $x_{m-1}$]|$y_{N-1,m-1}$|
51
+
52
+ For more details see the `globalemu` [paper](https://arxiv.org/abs/2104.04336).
53
+ A paper is in preparation demonstrating applications of this package to a broad
54
+ range of astrophysical signals.
55
+
56
+ ## Contributions
57
+
58
+ Contributions are welcome! Please open an issue to discuss and have a
59
+ read of the Contribution guidelines.
@@ -0,0 +1,14 @@
1
+ LICENSE
2
+ README.md
3
+ pyproject.toml
4
+ astroemu/__init__.py
5
+ astroemu/_version.py
6
+ astroemu/dataloaders.py
7
+ astroemu/network.py
8
+ astroemu/normalisation.py
9
+ astroemu/utils.py
10
+ astroemu.egg-info/PKG-INFO
11
+ astroemu.egg-info/SOURCES.txt
12
+ astroemu.egg-info/dependency_links.txt
13
+ astroemu.egg-info/requires.txt
14
+ astroemu.egg-info/top_level.txt
@@ -0,0 +1,7 @@
1
+ jax
2
+ jaxlib
3
+ torch
4
+
5
+ [dev]
6
+ ruff
7
+ pre-commit
@@ -0,0 +1 @@
1
+ astroemu
@@ -0,0 +1,39 @@
1
+ [project]
2
+ name = "astroemu"
3
+ version = "0.1.2"
4
+ description = "Generalised framework for emulating spectral signals."
5
+ readme = "README.md"
6
+ requires-python = ">=3.11"
7
+ dependencies = [
8
+ "jax",
9
+ "jaxlib",
10
+ "torch",
11
+ ]
12
+
13
+ [build-system]
14
+ requires = ["setuptools>=77.0", "wheel"]
15
+ build-backend = "setuptools.build_meta"
16
+
17
+ [tool.setuptools.packages.find]
18
+ where = ["."]
19
+ include = ["astroemu*"]
20
+
21
+ [project.optional-dependencies]
22
+ dev = [
23
+ "ruff",
24
+ "pre-commit",
25
+ ]
26
+
27
+ [tool.ruff]
28
+ line-length = 79
29
+
30
+ [tool.ruff.lint]
31
+ select =["E", "F", "W", # basics
32
+ "I", # isort
33
+ "D", # docstrings
34
+ "UP", # pyupgrade (includes type modernization)
35
+ "ANN", # type annotations
36
+ ]
37
+
38
+ [tool.ruff.lint.pydocstyle]
39
+ convention = "google"
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+