diffstar 0.3.2__tar.gz → 0.3.3__tar.gz
Sign up to get free protection for your applications and to get access to all the features.
- diffstar-0.3.2/.github/workflows/linting.yml → diffstar-0.3.3/.github/workflows/linting.yaml +2 -3
- diffstar-0.3.2/.github/workflows/monthly-warning-test.yml → diffstar-0.3.3/.github/workflows/monthly-warning-test.yaml +2 -4
- diffstar-0.3.2/.github/workflows/test_releases.yml → diffstar-0.3.3/.github/workflows/test_releases.yaml +2 -4
- diffstar-0.3.2/.github/workflows/tests_cron.yml → diffstar-0.3.3/.github/workflows/tests_cron.yaml +2 -4
- {diffstar-0.3.2 → diffstar-0.3.3}/CHANGES.rst +5 -0
- {diffstar-0.3.2/diffstar.egg-info → diffstar-0.3.3}/PKG-INFO +4 -4
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/__init__.py +2 -1
- diffstar-0.3.3/diffstar/_version.py +1 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/data_loaders/load_bpl.py +2 -1
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/defaults.py +3 -2
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/fitting_helpers/fit_smah_helpers_tpeak.py +1 -2
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/fitting_helpers/fitting_kernels.py +7 -11
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/fitting_helpers/param_clippers.py +2 -1
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/fitting_helpers/tests/test_fitting_kernels.py +7 -43
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/fitting_helpers/tests/test_fitting_kernels_are_frozen.py +24 -27
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/fitting_helpers/tests/test_fitting_smah_helpers.py +8 -10
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/fitting_helpers/tests/test_param_clippers.py +2 -1
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/fitting_helpers/tests/test_stars.py +2 -1
- diffstar-0.3.3/diffstar/kernels/__init__.py +4 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/kernels/tests/test_quenching_kernels.py +1 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/test_defaults.py +2 -1
- diffstar-0.3.3/diffstar/tests/test_sfh_model_tpeak.py +62 -0
- {diffstar-0.3.2 → diffstar-0.3.3/diffstar.egg-info}/PKG-INFO +4 -4
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar.egg-info/SOURCES.txt +4 -21
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar.egg-info/requires.txt +1 -1
- {diffstar-0.3.2 → diffstar-0.3.3}/docs/source/index.rst +0 -1
- {diffstar-0.3.2 → diffstar-0.3.3}/docs/source/rtd_environment.yaml +2 -2
- {diffstar-0.3.2 → diffstar-0.3.3}/pyproject.toml +2 -2
- diffstar-0.3.3/requirements.txt +4 -0
- diffstar-0.3.2/diffstar/_version.py +0 -1
- diffstar-0.3.2/diffstar/fitting_helpers/fit_smah_helpers.py +0 -1715
- diffstar-0.3.2/diffstar/kernels/__init__.py +0 -5
- diffstar-0.3.2/diffstar/kernels/history_kernel_builders.py +0 -268
- diffstar-0.3.2/diffstar/kernels/kernel_builders.py +0 -249
- diffstar-0.3.2/diffstar/kernels/main_sequence_kernels.py +0 -233
- diffstar-0.3.2/diffstar/kernels/tests/test_frozen_diffstar_kernels.py +0 -121
- diffstar-0.3.2/diffstar/kernels/tests/test_kernel_builders.py +0 -161
- diffstar-0.3.2/diffstar/kernels/tests/test_kernel_builders_tpeak.py +0 -172
- diffstar-0.3.2/diffstar/sfh.py +0 -220
- diffstar-0.3.2/diffstar/sfh_model.py +0 -140
- diffstar-0.3.2/diffstar/tests/test_gas.py +0 -40
- diffstar-0.3.2/diffstar/tests/test_lax_main_sequence.py +0 -153
- diffstar-0.3.2/diffstar/tests/test_lax_sfh.py +0 -154
- diffstar-0.3.2/diffstar/tests/test_main_sequence_kernels.py +0 -68
- diffstar-0.3.2/diffstar/tests/test_sfh.py +0 -198
- diffstar-0.3.2/diffstar/tests/test_sfh_model.py +0 -156
- diffstar-0.3.2/diffstar/tests/test_sfh_model_tpeak.py +0 -155
- diffstar-0.3.2/docs/source/demo_diffstar_fitter.ipynb +0 -337
- diffstar-0.3.2/docs/source/demos.rst +0 -8
- diffstar-0.3.2/requirements.txt +0 -4
- {diffstar-0.3.2 → diffstar-0.3.3}/.coveragerc +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/.git_archival.txt +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/.gitattributes +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/.gitignore +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/.readthedocs.yml +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/LICENSE.rst +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/README.md +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/data_loaders/__init__.py +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/data_loaders/load_smah_data.py +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/data_loaders/tests/__init__.py +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/data_loaders/tests/test_load_bpl.py +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/data_loaders/tests/test_load_smah_data.py +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/data_loaders/tests/testing_data/subvol_000_diffmah_fits.h5 +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/diffstarnet/__init__.py +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/diffstarnet/diffstarnet_tdata.py +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/diffstarnet/tests/__init__.py +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/diffstarnet/tests/test_diffstarnet_tdata.py +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/fitting_helpers/__init__.py +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/fitting_helpers/stars.py +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/fitting_helpers/tests/__init__.py +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/fitting_helpers/utils.py +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/kernels/gas_consumption.py +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/kernels/history_kernel_builders_tpeak.py +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/kernels/main_sequence_kernels_tpeak.py +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/kernels/quenching_kernels.py +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/kernels/tests/__init__.py +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/kernels/tests/testing_data/mah_params_testing_v0.1.0.txt +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/kernels/tests/testing_data/ms_params_testing_v0.1.0.txt +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/kernels/tests/testing_data/q_params_testing_v0.1.0.txt +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/kernels/tests/testing_data/sfh_table_testing_v0.1.0.txt +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/kernels/tests/testing_data/t_table_testing_v0.1.0.txt +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/kernels/tests/testing_data/u_ms_params_testing_v0.1.0.txt +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/kernels/tests/testing_data/u_q_params_testing_v0.1.0.txt +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/sfh_model_tpeak.py +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/__init__.py +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/test_quenching.py +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/test_utils.py +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/default_params_test_diffmah_params.txt +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/default_params_test_dmhdt.txt +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/default_params_test_dt.txt +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/default_params_test_lgt.txt +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/default_params_test_log_mah.txt +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/default_params_test_sfh.txt +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/default_params_test_u_ms_params.txt +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/default_params_test_u_q_params.txt +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/dt_bpl.txt +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/halo_ids_test_sample.txt +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/lgt_bpl.txt +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/mah_params_test_sample.txt +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/ms_u_params_test_sample.txt +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/q_u_params_test_sample.txt +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/sfh_test_sample.txt +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/utils.py +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar.egg-info/dependency_links.txt +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/diffstar.egg-info/top_level.txt +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/docs/Makefile +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/docs/make.bat +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/docs/source/_static/README.txt +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/docs/source/citation.rst +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/docs/source/conf.py +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/docs/source/demo_diffstar_sfh.ipynb +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/docs/source/installation.rst +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/docs/source/reference.rst +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/scripts/generate_unit_testing_data.py +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/scripts/history_fitting_script.py +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/scripts/history_fitting_script_SMDPL_tpeak.py +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/setup.cfg +0 -0
- {diffstar-0.3.2 → diffstar-0.3.3}/setup.py +0 -0
diffstar-0.3.2/.github/workflows/linting.yml → diffstar-0.3.3/.github/workflows/linting.yaml
RENAMED
@@ -21,15 +21,14 @@ jobs:
|
|
21
21
|
channel-priority: strict
|
22
22
|
show-channel-urls: true
|
23
23
|
miniforge-version: latest
|
24
|
-
miniforge-variant: Mambaforge
|
25
24
|
|
26
25
|
- name: configure conda and install code
|
27
26
|
shell: bash -l {0}
|
28
27
|
run: |
|
29
|
-
|
28
|
+
conda install --quiet \
|
30
29
|
--file=requirements.txt
|
31
30
|
python -m pip install --no-deps -e .
|
32
|
-
|
31
|
+
conda install -y -q \
|
33
32
|
flake8
|
34
33
|
|
35
34
|
- name: lint
|
@@ -23,18 +23,16 @@ jobs:
|
|
23
23
|
channel-priority: strict
|
24
24
|
show-channel-urls: true
|
25
25
|
miniforge-version: latest
|
26
|
-
miniforge-variant: Mambaforge
|
27
|
-
use-mamba: true
|
28
26
|
|
29
27
|
- name: configure conda and install code
|
30
28
|
# Test against current main branch of diffmah and dsps
|
31
29
|
shell: bash -l {0}
|
32
30
|
run: |
|
33
31
|
conda config --set always_yes yes
|
34
|
-
|
32
|
+
conda install --quiet \
|
35
33
|
--file=requirements.txt
|
36
34
|
python -m pip install --no-deps -e .
|
37
|
-
|
35
|
+
conda install -y -q \
|
38
36
|
flake8 \
|
39
37
|
pytest \
|
40
38
|
pytest-xdist \
|
@@ -27,18 +27,16 @@ jobs:
|
|
27
27
|
channel-priority: strict
|
28
28
|
show-channel-urls: true
|
29
29
|
miniforge-version: latest
|
30
|
-
miniforge-variant: Mambaforge
|
31
|
-
use-mamba: true
|
32
30
|
|
33
31
|
- name: configure conda and install code
|
34
32
|
# Test against current confa-forge release of diffmah
|
35
33
|
shell: bash -l {0}
|
36
34
|
run: |
|
37
35
|
conda config --set always_yes yes
|
38
|
-
|
36
|
+
conda install --quiet \
|
39
37
|
--file=requirements.txt
|
40
38
|
python -m pip install --no-deps -e .
|
41
|
-
|
39
|
+
conda install -y -q \
|
42
40
|
flake8 \
|
43
41
|
pytest \
|
44
42
|
pytest-xdist \
|
diffstar-0.3.2/.github/workflows/tests_cron.yml → diffstar-0.3.3/.github/workflows/tests_cron.yaml
RENAMED
@@ -27,18 +27,16 @@ jobs:
|
|
27
27
|
channel-priority: strict
|
28
28
|
show-channel-urls: true
|
29
29
|
miniforge-version: latest
|
30
|
-
miniforge-variant: Mambaforge
|
31
|
-
use-mamba: true
|
32
30
|
|
33
31
|
- name: configure conda and install code
|
34
32
|
# Test against current main branch of diffmah and dsps
|
35
33
|
shell: bash -l {0}
|
36
34
|
run: |
|
37
35
|
conda config --set always_yes yes
|
38
|
-
|
36
|
+
conda install --quiet \
|
39
37
|
--file=requirements.txt
|
40
38
|
python -m pip install --no-deps -e .
|
41
|
-
|
39
|
+
conda install -y -q \
|
42
40
|
flake8 \
|
43
41
|
pytest \
|
44
42
|
pytest-xdist \
|
@@ -1,6 +1,6 @@
|
|
1
|
-
Metadata-Version: 2.
|
1
|
+
Metadata-Version: 2.2
|
2
2
|
Name: diffstar
|
3
|
-
Version: 0.3.
|
3
|
+
Version: 0.3.3
|
4
4
|
Summary: Differentiable Star Formation Histories
|
5
5
|
Author-email: Alex Alarcon <alexalarcongonzalez@gmail.com>
|
6
6
|
License: BSD 3-Clause License
|
@@ -35,10 +35,10 @@ License: BSD 3-Clause License
|
|
35
35
|
|
36
36
|
Project-URL: home, https://github.com/ArgonneCPAC/diffstar
|
37
37
|
Classifier: Programming Language :: Python :: 3
|
38
|
-
Requires-Python: >=3.
|
38
|
+
Requires-Python: >=3.11
|
39
39
|
Description-Content-Type: text/markdown
|
40
40
|
License-File: LICENSE.rst
|
41
|
-
Requires-Dist: diffmah>=0.
|
41
|
+
Requires-Dist: diffmah>=0.7.0
|
42
42
|
Requires-Dist: numpy
|
43
43
|
Requires-Dist: jax
|
44
44
|
Requires-Dist: h5py
|
@@ -1,5 +1,6 @@
|
|
1
1
|
"""
|
2
2
|
"""
|
3
|
+
|
3
4
|
# flake8: noqa
|
4
5
|
from ._version import __version__
|
5
6
|
from .defaults import (
|
@@ -14,4 +15,4 @@ from .defaults import (
|
|
14
15
|
get_bounded_diffstar_params,
|
15
16
|
get_unbounded_diffstar_params,
|
16
17
|
)
|
17
|
-
from .
|
18
|
+
from .sfh_model_tpeak import calc_sfh_galpop, calc_sfh_singlegal
|
@@ -0,0 +1 @@
|
|
1
|
+
__version__ = '0.3.3'
|
@@ -1,5 +1,6 @@
|
|
1
1
|
"""
|
2
2
|
"""
|
3
|
+
|
3
4
|
import os
|
4
5
|
from collections import OrderedDict
|
5
6
|
|
@@ -13,7 +14,7 @@ try:
|
|
13
14
|
except ImportError:
|
14
15
|
HAS_ASTROPY = False
|
15
16
|
|
16
|
-
from ..kernels.
|
17
|
+
from ..kernels.main_sequence_kernels_tpeak import _get_bounded_sfr_params_vmap
|
17
18
|
from ..kernels.quenching_kernels import _get_bounded_q_params_vmap
|
18
19
|
from ..utils import _jax_get_dt_array
|
19
20
|
|
@@ -1,10 +1,11 @@
|
|
1
1
|
"""
|
2
2
|
"""
|
3
|
+
|
3
4
|
# flake8: noqa
|
4
5
|
from collections import namedtuple
|
5
6
|
|
6
7
|
import numpy as np
|
7
|
-
from diffmah.
|
8
|
+
from diffmah.diffmah_kernels import DEFAULT_MAH_PARAMS, DEFAULT_MAH_PDICT
|
8
9
|
from jax import jit as jjit
|
9
10
|
|
10
11
|
TODAY = 13.8
|
@@ -20,7 +21,7 @@ DEFAULT_N_STEPS = 50
|
|
20
21
|
|
21
22
|
|
22
23
|
from .kernels.gas_consumption import FB
|
23
|
-
from .kernels.
|
24
|
+
from .kernels.main_sequence_kernels_tpeak import (
|
24
25
|
DEFAULT_MS_PARAMS,
|
25
26
|
DEFAULT_MS_PDICT,
|
26
27
|
DEFAULT_U_MS_PARAMS,
|
@@ -146,7 +146,6 @@ def get_loss_data_default(
|
|
146
146
|
log_smah_sim,
|
147
147
|
logmp,
|
148
148
|
mah_params,
|
149
|
-
t_peak,
|
150
149
|
dlogm_cut=DLOGM_CUT,
|
151
150
|
t_fit_min=T_FIT_MIN,
|
152
151
|
mass_fit_min=MIN_MASS_CUT,
|
@@ -251,7 +250,7 @@ def get_loss_data_default(
|
|
251
250
|
)
|
252
251
|
|
253
252
|
logt = jnp.log10(t_sim)
|
254
|
-
dmhdt, log_mah = _diffmah_kern(mah_params, t_sim,
|
253
|
+
dmhdt, log_mah = _diffmah_kern(mah_params, t_sim, lgt0)
|
255
254
|
|
256
255
|
weight, weight_fstar = get_weights(
|
257
256
|
t_sim,
|
@@ -1,14 +1,14 @@
|
|
1
1
|
"""
|
2
2
|
"""
|
3
3
|
|
4
|
-
from diffmah
|
4
|
+
from diffmah import DEFAULT_MAH_PARAMS, mah_singlehalo
|
5
5
|
from jax import jit as jjit
|
6
6
|
from jax import numpy as jnp
|
7
7
|
from jax import vmap
|
8
8
|
|
9
|
-
from ..defaults import FB
|
9
|
+
from ..defaults import FB, LGT0
|
10
10
|
from ..kernels.gas_consumption import _get_lagged_gas
|
11
|
-
from ..kernels.
|
11
|
+
from ..kernels.main_sequence_kernels_tpeak import (
|
12
12
|
MS_BOUNDING_SIGMOID_PDICT,
|
13
13
|
_get_bounded_sfr_params,
|
14
14
|
_sfr_eff_plaw,
|
@@ -127,13 +127,7 @@ def calculate_sm_sfr_history_from_mah(
|
|
127
127
|
|
128
128
|
@jjit
|
129
129
|
def calculate_histories(
|
130
|
-
lgt,
|
131
|
-
dt,
|
132
|
-
mah_params,
|
133
|
-
u_ms_params,
|
134
|
-
u_q_params,
|
135
|
-
fstar_tdelay,
|
136
|
-
fb=FB,
|
130
|
+
lgt, dt, mah_params, u_ms_params, u_q_params, fstar_tdelay, fb=FB, lgt0=LGT0
|
137
131
|
):
|
138
132
|
"""Calculate individual halo mass MAH and galaxy SFH
|
139
133
|
|
@@ -185,7 +179,9 @@ def calculate_histories(
|
|
185
179
|
Base-10 log of cumulative peak halo mass in units of Msun assuming h=1
|
186
180
|
|
187
181
|
"""
|
188
|
-
|
182
|
+
tarr = 10**lgt
|
183
|
+
mah_params = DEFAULT_MAH_PARAMS._make(mah_params)
|
184
|
+
dmhdt, log_mah = mah_singlehalo(mah_params, tarr, lgt0)
|
189
185
|
mstar, sfr, fstar = calculate_sm_sfr_fstar_history_from_mah(
|
190
186
|
lgt,
|
191
187
|
dt,
|
@@ -1,13 +1,14 @@
|
|
1
1
|
"""Functions ms_param_clipper and q_param_clipper implement clips on the diffstar
|
2
2
|
parameters to help protect against NaNs and infinities
|
3
3
|
"""
|
4
|
+
|
4
5
|
from collections import OrderedDict
|
5
6
|
|
6
7
|
from jax import jit as jjit
|
7
8
|
from jax import numpy as jnp
|
8
9
|
from jax import vmap
|
9
10
|
|
10
|
-
from ..kernels.
|
11
|
+
from ..kernels.main_sequence_kernels_tpeak import MS_PARAM_BOUNDS_PDICT
|
11
12
|
from ..kernels.quenching_kernels import Q_PARAM_BOUNDS_PDICT
|
12
13
|
|
13
14
|
_EPS = 0.001
|
@@ -2,14 +2,12 @@
|
|
2
2
|
"""
|
3
3
|
|
4
4
|
import numpy as np
|
5
|
-
from diffmah
|
6
|
-
from diffmah.
|
5
|
+
from diffmah import mah_singlehalo
|
6
|
+
from diffmah.defaults import DEFAULT_MAH_PARAMS
|
7
7
|
|
8
8
|
from ...defaults import DEFAULT_U_MS_PARAMS, DEFAULT_U_Q_PARAMS, FB, LGT0
|
9
|
-
from ...kernels import get_sfh_from_mah_kern
|
10
9
|
from ...utils import _jax_get_dt_array
|
11
10
|
from ..fitting_kernels import (
|
12
|
-
_sfr_history_from_mah,
|
13
11
|
calculate_histories,
|
14
12
|
calculate_histories_vmap,
|
15
13
|
calculate_sm_sfr_fstar_history_from_mah,
|
@@ -19,17 +17,6 @@ from ..fitting_kernels import (
|
|
19
17
|
DEFAULT_LOGM0 = 12.0
|
20
18
|
|
21
19
|
|
22
|
-
def _get_default_diffmah_args():
|
23
|
-
return (
|
24
|
-
LGT0,
|
25
|
-
DEFAULT_MAH_PARAMS.logmp,
|
26
|
-
DEFAULT_MAH_PARAMS.logtc,
|
27
|
-
MAH_K,
|
28
|
-
DEFAULT_MAH_PARAMS.early_index,
|
29
|
-
DEFAULT_MAH_PARAMS.late_index,
|
30
|
-
)
|
31
|
-
|
32
|
-
|
33
20
|
def test_calculate_sm_sfr_fstar_history_from_mah():
|
34
21
|
n_t = 100
|
35
22
|
tarr = np.linspace(0.1, 10**LGT0, n_t)
|
@@ -41,7 +28,7 @@ def test_calculate_sm_sfr_fstar_history_from_mah():
|
|
41
28
|
_mask = tarr > fstar_tdelay + fstar_tdelay / 2.0
|
42
29
|
fstar_indx_high = fstar_indx_high[_mask]
|
43
30
|
|
44
|
-
dmhdt, log_mah =
|
31
|
+
dmhdt, log_mah = mah_singlehalo(DEFAULT_MAH_PARAMS, tarr, LGT0)
|
45
32
|
args = (
|
46
33
|
lgtarr,
|
47
34
|
dtarr,
|
@@ -68,7 +55,7 @@ def test_calculate_sm_sfr_history_from_mah():
|
|
68
55
|
tarr = np.linspace(0.1, 10**LGT0, n_t)
|
69
56
|
lgtarr = np.log10(tarr)
|
70
57
|
dtarr = _jax_get_dt_array(tarr)
|
71
|
-
dmhdt, log_mah =
|
58
|
+
dmhdt, log_mah = mah_singlehalo(DEFAULT_MAH_PARAMS, tarr, LGT0)
|
72
59
|
|
73
60
|
args = lgtarr, dtarr, dmhdt, log_mah, DEFAULT_U_MS_PARAMS, DEFAULT_U_Q_PARAMS, FB
|
74
61
|
_res = calculate_sm_sfr_history_from_mah(*args)
|
@@ -86,7 +73,6 @@ def test_calculate_histories():
|
|
86
73
|
tarr = np.linspace(0.1, 10**LGT0, n_t)
|
87
74
|
lgtarr = np.log10(tarr)
|
88
75
|
dtarr = _jax_get_dt_array(tarr)
|
89
|
-
all_diffmah_args = _get_default_diffmah_args()
|
90
76
|
|
91
77
|
fstar_tdelay = 0.5 # gyr
|
92
78
|
fstar_indx_high = np.searchsorted(tarr, tarr - fstar_tdelay)
|
@@ -96,7 +82,7 @@ def test_calculate_histories():
|
|
96
82
|
args = (
|
97
83
|
lgtarr,
|
98
84
|
dtarr,
|
99
|
-
|
85
|
+
DEFAULT_MAH_PARAMS,
|
100
86
|
DEFAULT_U_MS_PARAMS,
|
101
87
|
DEFAULT_U_Q_PARAMS,
|
102
88
|
fstar_tdelay,
|
@@ -112,16 +98,12 @@ def test_calculate_histories():
|
|
112
98
|
assert x.shape == (n_t,)
|
113
99
|
assert np.all(np.diff(log_mah) > 0)
|
114
100
|
|
115
|
-
logmp = all_diffmah_args[1]
|
116
|
-
assert log_mah[-1] == logmp
|
117
|
-
|
118
101
|
|
119
102
|
def test_calculate_histories_vmap():
|
120
103
|
n_t = 100
|
121
104
|
tarr = np.linspace(0.1, 10**LGT0, n_t)
|
122
105
|
lgtarr = np.log10(tarr)
|
123
106
|
dtarr = _jax_get_dt_array(tarr)
|
124
|
-
all_diffmah_args = np.array(_get_default_diffmah_args()).reshape((1, -1))
|
125
107
|
u_ms_params = np.array(DEFAULT_U_MS_PARAMS).reshape((1, -1))
|
126
108
|
u_q_params = np.array(DEFAULT_U_Q_PARAMS).reshape((1, -1))
|
127
109
|
fstar_tdelay = 0.5 # gyr
|
@@ -129,11 +111,12 @@ def test_calculate_histories_vmap():
|
|
129
111
|
_mask = tarr > fstar_tdelay + fstar_tdelay / 2.0
|
130
112
|
fstar_indx_high = fstar_indx_high[_mask]
|
131
113
|
|
114
|
+
mah_params = DEFAULT_MAH_PARAMS._make([np.zeros(1) + x for x in DEFAULT_MAH_PARAMS])
|
132
115
|
# in_axes = (None, None, 0, 0, 0, None, None, None)
|
133
116
|
args = (
|
134
117
|
lgtarr,
|
135
118
|
dtarr,
|
136
|
-
|
119
|
+
mah_params,
|
137
120
|
u_ms_params,
|
138
121
|
u_q_params,
|
139
122
|
fstar_tdelay,
|
@@ -145,22 +128,3 @@ def test_calculate_histories_vmap():
|
|
145
128
|
mstar_galpop, sfr_galpop, fstar_galpop, dmhdt_galpop, log_mah_galpop = _res
|
146
129
|
for x in mstar_galpop, sfr_galpop, dmhdt_galpop, log_mah_galpop:
|
147
130
|
assert x.shape == (1, n_t)
|
148
|
-
|
149
|
-
|
150
|
-
def test_sfr_history_from_mah():
|
151
|
-
n_t = 200
|
152
|
-
tarr = np.linspace(0.1, 10**LGT0, n_t)
|
153
|
-
lgtarr = np.log10(tarr)
|
154
|
-
dtarr = _jax_get_dt_array(tarr)
|
155
|
-
all_diffmah_args = _get_default_diffmah_args()
|
156
|
-
dmhdt, log_mah = _calc_halo_history(lgtarr, *all_diffmah_args)
|
157
|
-
args = lgtarr, dtarr, dmhdt, log_mah, DEFAULT_U_MS_PARAMS, DEFAULT_U_Q_PARAMS, FB
|
158
|
-
sfh_from_fitting_kernels = _sfr_history_from_mah(*args)
|
159
|
-
lgt0, logmp, logtc, k, early_index, late_index = all_diffmah_args
|
160
|
-
mah_params = logmp, logtc, early_index, late_index
|
161
|
-
|
162
|
-
sfh_kern = get_sfh_from_mah_kern(tobs_loop="vmap")
|
163
|
-
sfh_from_diffstar_kernels = sfh_kern(
|
164
|
-
tarr, mah_params, DEFAULT_U_MS_PARAMS, DEFAULT_U_Q_PARAMS, LGT0, FB
|
165
|
-
)
|
166
|
-
assert np.allclose(sfh_from_fitting_kernels, sfh_from_diffstar_kernels, atol=0.01)
|
{diffstar-0.3.2 → diffstar-0.3.3}/diffstar/fitting_helpers/tests/test_fitting_kernels_are_frozen.py
RENAMED
@@ -1,10 +1,11 @@
|
|
1
1
|
"""Unit tests enforcing that the behavior of Diffstar on the default params is frozen.
|
2
2
|
"""
|
3
|
+
|
3
4
|
import os
|
4
5
|
|
5
6
|
import numpy as np
|
7
|
+
from diffmah import mah_singlehalo
|
6
8
|
from diffmah.defaults import DEFAULT_MAH_PARAMS, MAH_K
|
7
|
-
from diffmah.individual_halo_assembly import _calc_halo_history
|
8
9
|
from jax import numpy as jnp
|
9
10
|
|
10
11
|
from ...defaults import DEFAULT_U_MS_PARAMS, DEFAULT_U_Q_PARAMS, LGT0
|
@@ -19,18 +20,6 @@ TESTING_DATA_DRN = os.path.join(
|
|
19
20
|
)
|
20
21
|
|
21
22
|
|
22
|
-
def _get_default_mah_params():
|
23
|
-
"""Return (logt0, logmp, logtc, k, early, late)"""
|
24
|
-
return (
|
25
|
-
LGT0,
|
26
|
-
DEFAULT_MAH_PARAMS.logmp,
|
27
|
-
DEFAULT_MAH_PARAMS.logtc,
|
28
|
-
MAH_K,
|
29
|
-
DEFAULT_MAH_PARAMS.early_index,
|
30
|
-
DEFAULT_MAH_PARAMS.late_index,
|
31
|
-
)
|
32
|
-
|
33
|
-
|
34
23
|
def _get_default_sfr_u_params():
|
35
24
|
u_ms_params = jnp.array(DEFAULT_U_MS_PARAMS)
|
36
25
|
u_q_params = jnp.array(DEFAULT_U_Q_PARAMS)
|
@@ -43,11 +32,10 @@ def calc_sfh_on_default_params(n_t=100):
|
|
43
32
|
This function is used to generate the unit-testing data used in this module
|
44
33
|
to freeze the behavior of Diffstar evaluated on the default parameters.
|
45
34
|
"""
|
46
|
-
mah_params = _get_default_mah_params()
|
47
35
|
|
48
36
|
lgt = jnp.linspace(-1, LGT0, n_t)
|
49
37
|
dt = _get_dt_array(10**lgt)
|
50
|
-
dmhdt, log_mah =
|
38
|
+
dmhdt, log_mah = mah_singlehalo(DEFAULT_MAH_PARAMS, 10**lgt, LGT0)
|
51
39
|
u_ms_params, u_q_params = _get_default_sfr_u_params()
|
52
40
|
args = lgt, dt, dmhdt, log_mah, u_ms_params, u_q_params
|
53
41
|
sfh = _sfr_history_from_mah(*args)
|
@@ -120,7 +108,14 @@ def test_diffmah_behavior_is_frozen():
|
|
120
108
|
(at which point this test will need to be updated).
|
121
109
|
|
122
110
|
"""
|
123
|
-
assumed_default_params =
|
111
|
+
assumed_default_params = (
|
112
|
+
LGT0,
|
113
|
+
DEFAULT_MAH_PARAMS.logm0,
|
114
|
+
DEFAULT_MAH_PARAMS.logtc,
|
115
|
+
MAH_K,
|
116
|
+
DEFAULT_MAH_PARAMS.early_index,
|
117
|
+
DEFAULT_MAH_PARAMS.late_index,
|
118
|
+
)
|
124
119
|
lgt0, logmp, logtc, k, early_index, late_index = assumed_default_params
|
125
120
|
|
126
121
|
msg = "Default age of the universe assumed by Diffmah has changed"
|
@@ -144,7 +139,7 @@ def test_diffmah_behavior_is_frozen():
|
|
144
139
|
|
145
140
|
args, sfh = calc_sfh_on_default_params()
|
146
141
|
lgt, dt, dmhdt, log_mah, u_ms_params, u_q_params = args
|
147
|
-
dmhdt, log_mah =
|
142
|
+
dmhdt, log_mah = mah_singlehalo(DEFAULT_MAH_PARAMS, 10**lgt, lgt0)
|
148
143
|
|
149
144
|
log_mah_fn = os.path.join(TESTING_DATA_DRN, "default_params_test_log_mah.txt")
|
150
145
|
frozen_log_mah = np.loadtxt(log_mah_fn)
|
@@ -171,26 +166,28 @@ def test_sfh_is_frozen_on_example_bpl_sample():
|
|
171
166
|
frozen_sfhs = np.loadtxt(sfh_fn)
|
172
167
|
lgt_bpl = np.loadtxt(lgt_fn)
|
173
168
|
dt_bpl = np.loadtxt(dt_fn)
|
169
|
+
t_bpl = 10**lgt_bpl
|
170
|
+
lgt0 = lgt_bpl[-1]
|
171
|
+
assert np.allclose(LGT0_BPL, lgt0, rtol=1e-3)
|
172
|
+
|
174
173
|
mah_params_test_sample = np.loadtxt(mah_params_fn)
|
175
174
|
ms_u_params_test_sample = np.loadtxt(ms_params_fn)
|
176
175
|
q_u_params_test_sample = np.loadtxt(q_params_fn)
|
177
176
|
|
178
177
|
sfh_test_sample = []
|
179
178
|
for ih in range(mah_params_test_sample.shape[0]):
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
mah_params_test_sample[ih, 3],
|
188
|
-
)
|
179
|
+
logm0 = mah_params_test_sample[ih, 0]
|
180
|
+
logtc = mah_params_test_sample[ih, 1]
|
181
|
+
early_indx = mah_params_test_sample[ih, 2]
|
182
|
+
late_indx = mah_params_test_sample[ih, 3]
|
183
|
+
t_peak = DEFAULT_MAH_PARAMS.t_peak
|
184
|
+
mah_params = DEFAULT_MAH_PARAMS._make(
|
185
|
+
[logm0, logtc, early_indx, late_indx, t_peak]
|
189
186
|
)
|
190
187
|
ms_u_params_ih = np.array(ms_u_params_test_sample[ih, :])
|
191
188
|
q_u_params_ih = np.array(q_u_params_test_sample[ih, :])
|
192
189
|
|
193
|
-
dmhdt_ih, log_mah_ih =
|
190
|
+
dmhdt_ih, log_mah_ih = mah_singlehalo(mah_params, t_bpl, lgt0)
|
194
191
|
sfh_ih = _sfr_history_from_mah(
|
195
192
|
lgt_bpl, dt_bpl, dmhdt_ih, log_mah_ih, ms_u_params_ih, q_u_params_ih
|
196
193
|
)
|
{diffstar-0.3.2 → diffstar-0.3.3}/diffstar/fitting_helpers/tests/test_fitting_smah_helpers.py
RENAMED
@@ -2,33 +2,33 @@
|
|
2
2
|
"""
|
3
3
|
|
4
4
|
import numpy as np
|
5
|
+
from diffmah.diffmah_kernels import DEFAULT_MAH_PARAMS
|
5
6
|
|
6
7
|
from ...defaults import DEFAULT_MS_PDICT, DEFAULT_Q_PDICT
|
7
8
|
from ...utils import _jax_get_dt_array
|
8
|
-
from ..
|
9
|
+
from ..fit_smah_helpers_tpeak import get_header, get_loss_data_default
|
9
10
|
|
10
11
|
DIFFMAH_K = 3.5
|
11
12
|
|
12
13
|
|
13
14
|
def test_get_header_colnames_agree_with_model_param_names():
|
14
|
-
header = get_header()
|
15
|
+
header, colnames = get_header()
|
15
16
|
assert header[0] == "#"
|
16
|
-
colnames = header[1:].strip().split()
|
17
17
|
|
18
18
|
assert colnames[0] == "halo_id"
|
19
19
|
|
20
20
|
u_ms_colnames_from_header = colnames[1:6]
|
21
|
-
ms_colnames_from_header = [s
|
21
|
+
ms_colnames_from_header = [s for s in u_ms_colnames_from_header]
|
22
22
|
assert ms_colnames_from_header == list(DEFAULT_MS_PDICT.keys())
|
23
23
|
|
24
24
|
u_q_colnames_from_header = colnames[6:10]
|
25
|
-
q_colnames_from_header = [s
|
25
|
+
q_colnames_from_header = [s for s in u_q_colnames_from_header]
|
26
26
|
assert q_colnames_from_header == list(DEFAULT_Q_PDICT.keys())
|
27
27
|
|
28
28
|
assert colnames[10:] == ["loss", "success"]
|
29
29
|
|
30
30
|
|
31
|
-
def
|
31
|
+
def test_get_loss_data_evaluates():
|
32
32
|
t_sim = np.linspace(0.1, 13.8, 100)
|
33
33
|
dt_sim = _jax_get_dt_array(t_sim)
|
34
34
|
sfrh = np.random.uniform(0, 10, t_sim.size)
|
@@ -36,8 +36,6 @@ def test_get_loss_data_fixed_hi():
|
|
36
36
|
log_smah_sim = np.log10(smh)
|
37
37
|
|
38
38
|
logmp = 12.0
|
39
|
-
|
40
|
-
|
41
|
-
p_init, loss_data = get_loss_data_fixed_hi(
|
42
|
-
t_sim, dt_sim, sfrh, log_smah_sim, logmp, mah_params
|
39
|
+
p_init, loss_data = get_loss_data_default(
|
40
|
+
t_sim, dt_sim, sfrh, log_smah_sim, logmp, DEFAULT_MAH_PARAMS
|
43
41
|
)
|
@@ -1,9 +1,10 @@
|
|
1
1
|
"""
|
2
2
|
"""
|
3
|
+
|
3
4
|
import numpy as np
|
4
5
|
from jax import random as jran
|
5
6
|
|
6
|
-
from ...kernels.
|
7
|
+
from ...kernels.main_sequence_kernels_tpeak import (
|
7
8
|
MS_PARAM_BOUNDS_PDICT,
|
8
9
|
_get_bounded_sfr_params_vmap,
|
9
10
|
_get_unbounded_sfr_params_vmap,
|
@@ -0,0 +1,62 @@
|
|
1
|
+
""""""
|
2
|
+
|
3
|
+
import numpy as np
|
4
|
+
from diffmah.diffmah_kernels import DEFAULT_MAH_PARAMS
|
5
|
+
from jax import random as jran
|
6
|
+
|
7
|
+
from ..defaults import (
|
8
|
+
DEFAULT_MS_PARAMS,
|
9
|
+
DEFAULT_Q_PARAMS,
|
10
|
+
DEFAULT_U_MS_PARAMS,
|
11
|
+
DEFAULT_U_Q_PARAMS,
|
12
|
+
FB,
|
13
|
+
LGT0,
|
14
|
+
DiffstarUParams,
|
15
|
+
MSUParams,
|
16
|
+
QUParams,
|
17
|
+
get_bounded_diffstar_params,
|
18
|
+
)
|
19
|
+
from ..sfh_model_tpeak import calc_sfh_singlegal
|
20
|
+
|
21
|
+
|
22
|
+
def _get_all_default_params():
|
23
|
+
ms_params, q_params = DEFAULT_MS_PARAMS, DEFAULT_Q_PARAMS
|
24
|
+
return LGT0, DEFAULT_MAH_PARAMS, ms_params, q_params
|
25
|
+
|
26
|
+
|
27
|
+
def _get_all_default_u_params():
|
28
|
+
u_ms_params, u_q_params = DEFAULT_U_MS_PARAMS, DEFAULT_U_Q_PARAMS
|
29
|
+
return LGT0, DEFAULT_MAH_PARAMS, u_ms_params, u_q_params
|
30
|
+
|
31
|
+
|
32
|
+
def test_calc_sfh_singlegal_imports_from_top_level():
|
33
|
+
from .. import calc_sfh_singlegal as _func # noqa
|
34
|
+
|
35
|
+
|
36
|
+
def test_calc_sfh_galpop_imports_from_top_level():
|
37
|
+
from .. import calc_sfh_galpop as _func # noqa
|
38
|
+
|
39
|
+
|
40
|
+
def test_sfh_singlegal_evaluates_on_wide_param_range():
|
41
|
+
lgt0, mah_params, u_ms_params_init, u_q_params_init = _get_all_default_u_params()
|
42
|
+
|
43
|
+
n_t = 100
|
44
|
+
tarr = np.linspace(0.1, 10**lgt0, n_t)
|
45
|
+
|
46
|
+
ran_key = jran.PRNGKey(0)
|
47
|
+
ntests = 20
|
48
|
+
ran_keys = jran.split(ran_key, ntests)
|
49
|
+
for test_key in ran_keys:
|
50
|
+
ms_key, q_key = jran.split(test_key, 2)
|
51
|
+
u_ms_params = jran.normal(ms_key, shape=(5,)) + np.array(u_ms_params_init)
|
52
|
+
u_q_params = jran.normal(q_key, shape=(4,)) + np.array(u_q_params_init)
|
53
|
+
sfh_u_params = DiffstarUParams(MSUParams(*u_ms_params), QUParams(*u_q_params))
|
54
|
+
sfh_params = get_bounded_diffstar_params(sfh_u_params)
|
55
|
+
sfh_new = calc_sfh_singlegal(sfh_params, mah_params, tarr, lgt0=lgt0, fb=FB)
|
56
|
+
assert np.all(np.isfinite(sfh_new))
|
57
|
+
|
58
|
+
sfh_new2, smh_new2 = calc_sfh_singlegal(
|
59
|
+
sfh_params, mah_params, tarr, lgt0=lgt0, fb=FB, return_smh=True
|
60
|
+
)
|
61
|
+
assert np.allclose(sfh_new, sfh_new2)
|
62
|
+
assert np.all(np.isfinite(smh_new2))
|
@@ -1,6 +1,6 @@
|
|
1
|
-
Metadata-Version: 2.
|
1
|
+
Metadata-Version: 2.2
|
2
2
|
Name: diffstar
|
3
|
-
Version: 0.3.
|
3
|
+
Version: 0.3.3
|
4
4
|
Summary: Differentiable Star Formation Histories
|
5
5
|
Author-email: Alex Alarcon <alexalarcongonzalez@gmail.com>
|
6
6
|
License: BSD 3-Clause License
|
@@ -35,10 +35,10 @@ License: BSD 3-Clause License
|
|
35
35
|
|
36
36
|
Project-URL: home, https://github.com/ArgonneCPAC/diffstar
|
37
37
|
Classifier: Programming Language :: Python :: 3
|
38
|
-
Requires-Python: >=3.
|
38
|
+
Requires-Python: >=3.11
|
39
39
|
Description-Content-Type: text/markdown
|
40
40
|
License-File: LICENSE.rst
|
41
|
-
Requires-Dist: diffmah>=0.
|
41
|
+
Requires-Dist: diffmah>=0.7.0
|
42
42
|
Requires-Dist: numpy
|
43
43
|
Requires-Dist: jax
|
44
44
|
Requires-Dist: h5py
|