diffstar 0.2.4__tar.gz → 0.3.1__tar.gz
Sign up to get free protection for your applications and to get access to all the features.
- {diffstar-0.2.4 → diffstar-0.3.1}/.github/workflows/linting.yml +1 -1
- diffstar-0.3.1/.github/workflows/monthly-warning-test.yml +55 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/.github/workflows/test_releases.yml +1 -1
- {diffstar-0.2.4 → diffstar-0.3.1}/.github/workflows/tests_cron.yml +3 -2
- {diffstar-0.2.4 → diffstar-0.3.1}/CHANGES.rst +10 -0
- {diffstar-0.2.4/diffstar.egg-info → diffstar-0.3.1}/PKG-INFO +2 -2
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/__init__.py +5 -2
- diffstar-0.3.1/diffstar/_version.py +1 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/defaults.py +44 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/kernels/gas_consumption.py +6 -0
- diffstar-0.3.1/diffstar/kernels/history_kernel_builders.py +268 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/kernels/main_sequence_kernels.py +36 -2
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/kernels/quenching_kernels.py +3 -0
- diffstar-0.3.1/diffstar/kernels/tests/test_kernel_builders.py +161 -0
- diffstar-0.3.1/diffstar/sfh_model.py +140 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/test_defaults.py +26 -0
- diffstar-0.3.1/diffstar/tests/test_main_sequence_kernels.py +28 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/test_sfh.py +1 -1
- diffstar-0.3.1/diffstar/tests/test_sfh_model.py +156 -0
- diffstar-0.3.1/diffstar/tests/test_utils.py +72 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/utils.py +80 -2
- {diffstar-0.2.4 → diffstar-0.3.1/diffstar.egg-info}/PKG-INFO +2 -2
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar.egg-info/SOURCES.txt +5 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/docs/source/demo_diffstar_fitter.ipynb +8 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/docs/source/demo_diffstar_sfh.ipynb +21 -25
- {diffstar-0.2.4 → diffstar-0.3.1}/pyproject.toml +1 -1
- diffstar-0.2.4/diffstar/_version.py +0 -1
- diffstar-0.2.4/diffstar/kernels/tests/test_kernel_builders.py +0 -8
- diffstar-0.2.4/diffstar/tests/test_utils.py +0 -24
- {diffstar-0.2.4 → diffstar-0.3.1}/.coveragerc +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/.git_archival.txt +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/.gitattributes +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/.gitignore +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/.readthedocs.yml +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/LICENSE.rst +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/README.md +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/data_loaders/__init__.py +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/data_loaders/load_bpl.py +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/data_loaders/load_smah_data.py +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/data_loaders/tests/__init__.py +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/data_loaders/tests/test_load_bpl.py +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/fitting_helpers/__init__.py +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/fitting_helpers/fit_smah_helpers.py +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/fitting_helpers/fitting_kernels.py +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/fitting_helpers/param_clippers.py +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/fitting_helpers/stars.py +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/fitting_helpers/tests/__init__.py +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/fitting_helpers/tests/test_fitting_kernels.py +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/fitting_helpers/tests/test_fitting_kernels_are_frozen.py +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/fitting_helpers/tests/test_fitting_smah_helpers.py +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/fitting_helpers/tests/test_param_clippers.py +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/fitting_helpers/tests/test_stars.py +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/fitting_helpers/utils.py +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/kernels/__init__.py +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/kernels/kernel_builders.py +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/kernels/tests/__init__.py +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/kernels/tests/test_frozen_diffstar_kernels.py +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/kernels/tests/test_quenching_kernels.py +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/kernels/tests/testing_data/mah_params_testing_v0.1.0.txt +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/kernels/tests/testing_data/ms_params_testing_v0.1.0.txt +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/kernels/tests/testing_data/q_params_testing_v0.1.0.txt +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/kernels/tests/testing_data/sfh_table_testing_v0.1.0.txt +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/kernels/tests/testing_data/t_table_testing_v0.1.0.txt +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/kernels/tests/testing_data/u_ms_params_testing_v0.1.0.txt +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/kernels/tests/testing_data/u_q_params_testing_v0.1.0.txt +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/sfh.py +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/__init__.py +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/test_gas.py +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/test_lax_main_sequence.py +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/test_lax_sfh.py +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/test_quenching.py +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/testing_data/default_params_test_diffmah_params.txt +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/testing_data/default_params_test_dmhdt.txt +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/testing_data/default_params_test_dt.txt +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/testing_data/default_params_test_lgt.txt +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/testing_data/default_params_test_log_mah.txt +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/testing_data/default_params_test_sfh.txt +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/testing_data/default_params_test_u_ms_params.txt +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/testing_data/default_params_test_u_q_params.txt +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/testing_data/dt_bpl.txt +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/testing_data/halo_ids_test_sample.txt +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/testing_data/lgt_bpl.txt +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/testing_data/mah_params_test_sample.txt +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/testing_data/ms_u_params_test_sample.txt +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/testing_data/q_u_params_test_sample.txt +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/testing_data/sfh_test_sample.txt +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar.egg-info/dependency_links.txt +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar.egg-info/requires.txt +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/diffstar.egg-info/top_level.txt +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/docs/Makefile +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/docs/make.bat +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/docs/source/_static/README.txt +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/docs/source/citation.rst +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/docs/source/conf.py +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/docs/source/demos.rst +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/docs/source/index.rst +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/docs/source/installation.rst +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/docs/source/reference.rst +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/docs/source/rtd_environment.yaml +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/requirements.txt +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/scripts/generate_unit_testing_data.py +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/scripts/history_fitting_script.py +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/setup.cfg +0 -0
- {diffstar-0.2.4 → diffstar-0.3.1}/setup.py +0 -0
@@ -0,0 +1,55 @@
|
|
1
|
+
name: Test for Warnings
|
2
|
+
|
3
|
+
on:
|
4
|
+
workflow_dispatch: null
|
5
|
+
schedule:
|
6
|
+
# Runs "First of every month at 3:15am Central"
|
7
|
+
- cron: '15 8 1 * *'
|
8
|
+
|
9
|
+
jobs:
|
10
|
+
tests:
|
11
|
+
name: tests
|
12
|
+
runs-on: "ubuntu-latest"
|
13
|
+
|
14
|
+
steps:
|
15
|
+
- uses: actions/checkout@v2
|
16
|
+
with:
|
17
|
+
fetch-depth: 0
|
18
|
+
|
19
|
+
- uses: conda-incubator/setup-miniconda@v2
|
20
|
+
with:
|
21
|
+
python-version: 3.11
|
22
|
+
channels: conda-forge,defaults
|
23
|
+
channel-priority: strict
|
24
|
+
show-channel-urls: true
|
25
|
+
miniforge-version: latest
|
26
|
+
miniforge-variant: Mambaforge
|
27
|
+
use-mamba: true
|
28
|
+
|
29
|
+
- name: configure conda and install code
|
30
|
+
# Test against current main branch of diffmah and dsps
|
31
|
+
shell: bash -l {0}
|
32
|
+
run: |
|
33
|
+
conda config --set always_yes yes
|
34
|
+
mamba install --quiet \
|
35
|
+
--file=requirements.txt
|
36
|
+
python -m pip install --no-deps -e .
|
37
|
+
mamba install -y -q \
|
38
|
+
flake8 \
|
39
|
+
pytest \
|
40
|
+
pytest-xdist \
|
41
|
+
pytest-cov \
|
42
|
+
pip \
|
43
|
+
setuptools \
|
44
|
+
"setuptools_scm>=7,<8" \
|
45
|
+
python-build
|
46
|
+
pip uninstall diffmah --yes
|
47
|
+
pip install --no-deps git+https://github.com/ArgonneCPAC/diffmah.git
|
48
|
+
pip install --no-deps git+https://github.com/ArgonneCPAC/dsps.git
|
49
|
+
python -m pip install --no-build-isolation --no-deps -e .
|
50
|
+
|
51
|
+
- name: test that no warnings are raised
|
52
|
+
shell: bash -l {0}
|
53
|
+
run: |
|
54
|
+
export PYTHONWARNINGS=error
|
55
|
+
pytest -v diffstar --cov --cov-report=xml
|
@@ -22,7 +22,7 @@ jobs:
|
|
22
22
|
|
23
23
|
- uses: conda-incubator/setup-miniconda@v2
|
24
24
|
with:
|
25
|
-
python-version: 3.
|
25
|
+
python-version: 3.11
|
26
26
|
channels: conda-forge,defaults
|
27
27
|
channel-priority: strict
|
28
28
|
show-channel-urls: true
|
@@ -31,7 +31,7 @@ jobs:
|
|
31
31
|
use-mamba: true
|
32
32
|
|
33
33
|
- name: configure conda and install code
|
34
|
-
# Test against current main branch of diffmah
|
34
|
+
# Test against current main branch of diffmah and dsps
|
35
35
|
shell: bash -l {0}
|
36
36
|
run: |
|
37
37
|
conda config --set always_yes yes
|
@@ -49,6 +49,7 @@ jobs:
|
|
49
49
|
python-build
|
50
50
|
pip uninstall diffmah --yes
|
51
51
|
pip install --no-deps git+https://github.com/ArgonneCPAC/diffmah.git
|
52
|
+
pip install --no-deps git+https://github.com/ArgonneCPAC/dsps.git
|
52
53
|
python -m pip install --no-build-isolation --no-deps -e .
|
53
54
|
|
54
55
|
- name: test
|
@@ -1,3 +1,13 @@
|
|
1
|
+
0.3.1 (2024-6-19)
|
2
|
+
------------------
|
3
|
+
- Performance improvements for calc_sfh_galpop and calc_sfh_singlegal
|
4
|
+
|
5
|
+
|
6
|
+
0.3.0 (2024-01-17)
|
7
|
+
------------------
|
8
|
+
- Implement new API for primary user-facing functions calc_sfh_galpop and calc_sfh_singlegal
|
9
|
+
|
10
|
+
|
1
11
|
0.2.4 (2024-01-16)
|
2
12
|
------------------
|
3
13
|
- Require diffmah>=0.5.0
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: diffstar
|
3
|
-
Version: 0.2.4
|
3
|
+
Version: 0.3.1
|
4
4
|
Summary: Differentiable Star Formation Histories
|
5
5
|
Author-email: Alex Alarcon <alexalarcongonzalez@gmail.com>
|
6
6
|
License: BSD 3-Clause License
|
@@ -35,7 +35,7 @@ License: BSD 3-Clause License
|
|
35
35
|
|
36
36
|
Project-URL: home, https://github.com/ArgonneCPAC/diffstar
|
37
37
|
Classifier: Programming Language :: Python :: 3
|
38
|
-
Requires-Python: >=3.
|
38
|
+
Requires-Python: >=3.9
|
39
39
|
Description-Content-Type: text/markdown
|
40
40
|
License-File: LICENSE.rst
|
41
41
|
Requires-Dist: diffmah>=0.5.0
|
@@ -7,8 +7,11 @@ from .defaults import (
|
|
7
7
|
DEFAULT_DIFFSTAR_U_PARAMS,
|
8
8
|
DiffstarParams,
|
9
9
|
DiffstarUParams,
|
10
|
+
MSParams,
|
11
|
+
MSUParams,
|
12
|
+
QParams,
|
13
|
+
QUParams,
|
10
14
|
get_bounded_diffstar_params,
|
11
15
|
get_unbounded_diffstar_params,
|
12
16
|
)
|
13
|
-
from .
|
14
|
-
from .sfh import sfh_galpop, sfh_singlegal
|
17
|
+
from .sfh_model import calc_sfh_galpop, calc_sfh_singlegal
|
@@ -0,0 +1 @@
|
|
1
|
+
__version__ = '0.3.1'
|
@@ -14,6 +14,7 @@ LGT0 = np.log10(TODAY)
|
|
14
14
|
# Constants related to SFH integrals
|
15
15
|
SFR_MIN = 1e-14
|
16
16
|
T_BIRTH_MIN = 0.001
|
17
|
+
T_TABLE_MIN = 0.01
|
17
18
|
N_T_LGSM_INTEGRATION = 100
|
18
19
|
DEFAULT_N_STEPS = 50
|
19
20
|
|
@@ -53,6 +54,28 @@ DEFAULT_DIFFSTAR_U_PARAMS = DiffstarUParams(DEFAULT_U_MS_PARAMS, DEFAULT_U_Q_PAR
|
|
53
54
|
|
54
55
|
@jjit
|
55
56
|
def get_bounded_diffstar_params(diffstar_u_params):
|
57
|
+
"""Calculate diffstar parameters from unbounded counterparts.
|
58
|
+
|
59
|
+
The returned diffstar_params is the input expected by diffstar.calc_sfh_singlegal
|
60
|
+
and diffstar.calc_sfh_galpop.
|
61
|
+
|
62
|
+
Parameters
|
63
|
+
----------
|
64
|
+
diffstar_u_params : namedtuple, length 2
|
65
|
+
DiffstarUParams = u_ms_params, u_q_params
|
66
|
+
u_ms_params and u_q_params are tuples of floats or ndarrays
|
67
|
+
u_ms_params = u_lgmcrit, u_lgy_at_mcrit, u_indx_lo, u_indx_hi, u_tau_dep
|
68
|
+
u_q_params = u_lg_qt, u_qlglgdt, u_lg_drop, u_lg_rejuv
|
69
|
+
|
70
|
+
Returns
|
71
|
+
-------
|
72
|
+
diffstar_params : namedtuple, length 2
|
73
|
+
DiffstarParams = ms_params, q_params
|
74
|
+
ms_params and q_params are tuples of floats or ndarrays
|
75
|
+
ms_params = lgmcrit, lgy_at_mcrit, indx_lo, indx_hi, tau_dep
|
76
|
+
q_params = lg_qt, qlglgdt, lg_drop, lg_rejuv
|
77
|
+
|
78
|
+
"""
|
56
79
|
ms_params = MSParams(*_get_bounded_sfr_params(*diffstar_u_params.u_ms_params))
|
57
80
|
q_params = QParams(*_get_bounded_q_params(*diffstar_u_params.u_q_params))
|
58
81
|
return DiffstarParams(ms_params, q_params)
|
@@ -60,6 +83,27 @@ def get_bounded_diffstar_params(diffstar_u_params):
|
|
60
83
|
|
61
84
|
@jjit
|
62
85
|
def get_unbounded_diffstar_params(diffstar_params):
|
86
|
+
"""Calculate unbounded diffstar parameters from standard params.
|
87
|
+
|
88
|
+
This is the inverse function to get_bounded_diffstar_params
|
89
|
+
|
90
|
+
Parameters
|
91
|
+
----------
|
92
|
+
diffstar_params : namedtuple, length 2
|
93
|
+
DiffstarParams = ms_params, q_params
|
94
|
+
ms_params and q_params are tuples of floats or ndarrays
|
95
|
+
ms_params = lgmcrit, lgy_at_mcrit, indx_lo, indx_hi, tau_dep
|
96
|
+
q_params = lg_qt, qlglgdt, lg_drop, lg_rejuv
|
97
|
+
|
98
|
+
Returns
|
99
|
+
-------
|
100
|
+
diffstar_u_params : namedtuple, length 2
|
101
|
+
DiffstarUParams = u_ms_params, u_q_params
|
102
|
+
u_ms_params and u_q_params are tuples of floats or ndarrays
|
103
|
+
u_ms_params = u_lgmcrit, u_lgy_at_mcrit, u_indx_lo, u_indx_hi, u_tau_dep
|
104
|
+
u_q_params = u_lg_qt, u_qlglgdt, u_lg_drop, u_lg_rejuv
|
105
|
+
|
106
|
+
"""
|
63
107
|
u_ms_params = MSUParams(*_get_unbounded_sfr_params(*diffstar_params.ms_params))
|
64
108
|
u_q_params = QUParams(*_get_unbounded_q_params(*diffstar_params.q_params))
|
65
109
|
return DiffstarUParams(u_ms_params, u_q_params)
|
@@ -1,5 +1,6 @@
|
|
1
1
|
"""Calculate mass of gas available for star formation from the freshly accreted gas.
|
2
2
|
"""
|
3
|
+
|
3
4
|
from jax import jit as jjit
|
4
5
|
from jax import lax
|
5
6
|
from jax import numpy as jnp
|
@@ -28,6 +29,11 @@ def _gas_conversion_kern(t_form, t_acc, dt, tau_dep, tau_dep_max):
|
|
28
29
|
return tri_kern
|
29
30
|
|
30
31
|
|
32
|
+
_vmap_gas_conversion_kern = jjit(
|
33
|
+
vmap(_gas_conversion_kern, in_axes=(None, 0, None, None, None))
|
34
|
+
)
|
35
|
+
|
36
|
+
|
31
37
|
_a, _b = (0, None, 0, None, None), (None, 0, None, None, None)
|
32
38
|
_depletion_kernel = jjit(vmap(vmap(_gas_conversion_kern, in_axes=_b), in_axes=_a))
|
33
39
|
|
@@ -0,0 +1,268 @@
|
|
1
|
+
"""
|
2
|
+
"""
|
3
|
+
|
4
|
+
from jax import jit as jjit
|
5
|
+
from jax import lax
|
6
|
+
from jax import numpy as jnp
|
7
|
+
from jax import vmap
|
8
|
+
|
9
|
+
from ..defaults import (
|
10
|
+
DEFAULT_DIFFSTAR_PARAMS,
|
11
|
+
DEFAULT_MAH_PARAMS,
|
12
|
+
DEFAULT_N_STEPS,
|
13
|
+
SFR_MIN,
|
14
|
+
T_BIRTH_MIN,
|
15
|
+
)
|
16
|
+
from .main_sequence_kernels import (
|
17
|
+
_lax_ms_sfh_scalar_kern_scan,
|
18
|
+
_lax_ms_sfh_scalar_kern_sum,
|
19
|
+
)
|
20
|
+
from .quenching_kernels import _quenching_kern
|
21
|
+
|
22
|
+
__all__ = ("build_sfh_from_mah_kernel",)
|
23
|
+
|
24
|
+
N_MAH_PARAMS = len(DEFAULT_MAH_PARAMS)
|
25
|
+
N_MS_PARAMS = len(DEFAULT_DIFFSTAR_PARAMS.ms_params)
|
26
|
+
N_Q_PARAMS = len(DEFAULT_DIFFSTAR_PARAMS.q_params)
|
27
|
+
|
28
|
+
|
29
|
+
def build_sfh_from_mah_kernel(
|
30
|
+
n_steps=DEFAULT_N_STEPS,
|
31
|
+
tacc_integration_min=T_BIRTH_MIN,
|
32
|
+
tobs_loop=None,
|
33
|
+
galpop_loop=None,
|
34
|
+
tform_loop="sum",
|
35
|
+
):
|
36
|
+
"""Build a JAX-jitted kernel to calculate SFHs of a galaxy population.
|
37
|
+
|
38
|
+
Parameters
|
39
|
+
----------
|
40
|
+
n_steps : int, optional
|
41
|
+
Number of timesteps to use in the tacc integration
|
42
|
+
|
43
|
+
tacc_integration_min : float, optional
|
44
|
+
Earliest time to use in the tacc integrations. Default is 0.01 Gyr.
|
45
|
+
|
46
|
+
tobs_loop : string, optional
|
47
|
+
Argument specifies whether the input time of observation is a scalar or array
|
48
|
+
Default argument is None, for a JAX kernel that assumes scalar input for tobs
|
49
|
+
For a JAX kernel that assumes an array input for tobs,
|
50
|
+
options are either 'vmap' or 'scan', specifying the calculation method
|
51
|
+
|
52
|
+
galpop_loop : string, optional
|
53
|
+
Argument specifies whether the input galaxy/halo parameters assumed by the
|
54
|
+
returned JAX kernel pertain to a single galaxy or a population.
|
55
|
+
Default argument is None, for a single-galaxy JAX kernel
|
56
|
+
For a JAX kernel that assumes galaxy population,
|
57
|
+
options are either 'vmap' or 'scan', specifying the calculation method
|
58
|
+
|
59
|
+
tform_loop : string
|
60
|
+
Use 'sum' for faster vmap-based calculation and 'scan' for slower alternative.
|
61
|
+
Default is 'sum'
|
62
|
+
|
63
|
+
Returns
|
64
|
+
-------
|
65
|
+
sfh_from_mah_kern : function
|
66
|
+
JAX-jitted function that calculates SFH in accord with the input arguments
|
67
|
+
Function signature is as follows:
|
68
|
+
|
69
|
+
def sfh_from_mah_kern(t, mah_params, ms_params, q_params, lgt0, fb):
|
70
|
+
return sfh
|
71
|
+
|
72
|
+
"""
|
73
|
+
if tform_loop == "sum":
|
74
|
+
_lax_ms_sfh_scalar_kern = _lax_ms_sfh_scalar_kern_sum
|
75
|
+
elif tform_loop == "scan":
|
76
|
+
_lax_ms_sfh_scalar_kern = _lax_ms_sfh_scalar_kern_scan
|
77
|
+
|
78
|
+
uniform_table = jnp.linspace(0, 1, n_steps)
|
79
|
+
|
80
|
+
@jjit
|
81
|
+
def _kern(
|
82
|
+
t_form,
|
83
|
+
logmp,
|
84
|
+
logtc,
|
85
|
+
early_index,
|
86
|
+
late_index,
|
87
|
+
lgmcrit,
|
88
|
+
lgy_at_mcrit,
|
89
|
+
indx_lo,
|
90
|
+
indx_hi,
|
91
|
+
tau_dep,
|
92
|
+
lg_qt,
|
93
|
+
qlglgdt,
|
94
|
+
lg_drop,
|
95
|
+
lg_rejuv,
|
96
|
+
lgt0,
|
97
|
+
fb,
|
98
|
+
):
|
99
|
+
mah_params = logmp, logtc, early_index, late_index
|
100
|
+
ms_params = lgmcrit, lgy_at_mcrit, indx_lo, indx_hi, tau_dep
|
101
|
+
|
102
|
+
t_min = jnp.max(jnp.array((tacc_integration_min, t_form - tau_dep)))
|
103
|
+
t_table = t_min + uniform_table * (t_form - t_min)
|
104
|
+
args = t_form, mah_params, ms_params, lgt0, fb, t_table
|
105
|
+
ms_sfr = _lax_ms_sfh_scalar_kern(*args)
|
106
|
+
lgt_form = jnp.log10(t_form)
|
107
|
+
|
108
|
+
lg_q_dt = 10**qlglgdt
|
109
|
+
qkern_inputs = lg_qt, lg_q_dt, lg_drop, lg_rejuv
|
110
|
+
qfunc = _quenching_kern(lgt_form, *qkern_inputs)
|
111
|
+
sfr = qfunc * ms_sfr
|
112
|
+
sfr = lax.cond(sfr < SFR_MIN, lambda x: SFR_MIN, lambda x: x, sfr)
|
113
|
+
return sfr
|
114
|
+
|
115
|
+
kern_with_tobs_loop = _get_kern_with_tobs_loop(_kern, tobs_loop)
|
116
|
+
sfh_from_mah_kern = _get_kern_with_galpop_loop(kern_with_tobs_loop, galpop_loop)
|
117
|
+
|
118
|
+
return sfh_from_mah_kern
|
119
|
+
|
120
|
+
|
121
|
+
def _get_kern_with_tobs_loop(kern, tobs_loop):
|
122
|
+
if tobs_loop == "vmap":
|
123
|
+
_t = [
|
124
|
+
0,
|
125
|
+
*[None] * N_MAH_PARAMS,
|
126
|
+
*[None] * N_MS_PARAMS,
|
127
|
+
*[None] * N_Q_PARAMS,
|
128
|
+
None,
|
129
|
+
None,
|
130
|
+
]
|
131
|
+
new_kern = jjit(vmap(kern, in_axes=_t))
|
132
|
+
elif tobs_loop == "scan":
|
133
|
+
|
134
|
+
@jjit
|
135
|
+
def new_kern(
|
136
|
+
tarr,
|
137
|
+
logmp,
|
138
|
+
logtc,
|
139
|
+
early_index,
|
140
|
+
late_index,
|
141
|
+
lgmcrit,
|
142
|
+
lgy_at_mcrit,
|
143
|
+
indx_lo,
|
144
|
+
indx_hi,
|
145
|
+
tau_dep,
|
146
|
+
lg_qt,
|
147
|
+
qlglgdt,
|
148
|
+
lg_drop,
|
149
|
+
lg_rejuv,
|
150
|
+
lgt0,
|
151
|
+
fb,
|
152
|
+
):
|
153
|
+
@jjit
|
154
|
+
def scan_func_time_array(carryover, el):
|
155
|
+
t_form = el
|
156
|
+
sfr_at_t_form = kern(
|
157
|
+
t_form,
|
158
|
+
logmp,
|
159
|
+
logtc,
|
160
|
+
early_index,
|
161
|
+
late_index,
|
162
|
+
lgmcrit,
|
163
|
+
lgy_at_mcrit,
|
164
|
+
indx_lo,
|
165
|
+
indx_hi,
|
166
|
+
tau_dep,
|
167
|
+
lg_qt,
|
168
|
+
qlglgdt,
|
169
|
+
lg_drop,
|
170
|
+
lg_rejuv,
|
171
|
+
lgt0,
|
172
|
+
fb,
|
173
|
+
)
|
174
|
+
carryover = sfr_at_t_form
|
175
|
+
accumulated = sfr_at_t_form
|
176
|
+
return carryover, accumulated
|
177
|
+
|
178
|
+
scan_init = 0.0
|
179
|
+
scan_arr = tarr
|
180
|
+
res = lax.scan(scan_func_time_array, scan_init, scan_arr)
|
181
|
+
sfh = res[1]
|
182
|
+
return sfh
|
183
|
+
|
184
|
+
elif tobs_loop is None:
|
185
|
+
new_kern = kern
|
186
|
+
else:
|
187
|
+
msg = "Input `tobs_loop`={0} must be either `vmap` or `scan`"
|
188
|
+
raise ValueError(msg.format(tobs_loop))
|
189
|
+
return new_kern
|
190
|
+
|
191
|
+
|
192
|
+
def _get_kern_with_galpop_loop(kern, galpop_loop):
|
193
|
+
if galpop_loop == "vmap":
|
194
|
+
_g = [
|
195
|
+
None,
|
196
|
+
*[0] * N_MAH_PARAMS,
|
197
|
+
*[0] * N_MS_PARAMS,
|
198
|
+
*[0] * N_Q_PARAMS,
|
199
|
+
None,
|
200
|
+
None,
|
201
|
+
]
|
202
|
+
new_kern = jjit(vmap(kern, in_axes=_g))
|
203
|
+
elif galpop_loop == "scan":
|
204
|
+
|
205
|
+
@jjit
|
206
|
+
def new_kern(
|
207
|
+
t,
|
208
|
+
logmp,
|
209
|
+
logtc,
|
210
|
+
early_index,
|
211
|
+
late_index,
|
212
|
+
lgmcrit,
|
213
|
+
lgy_at_mcrit,
|
214
|
+
indx_lo,
|
215
|
+
indx_hi,
|
216
|
+
tau_dep,
|
217
|
+
lg_qt,
|
218
|
+
qlglgdt,
|
219
|
+
lg_drop,
|
220
|
+
lg_rejuv,
|
221
|
+
lgt0,
|
222
|
+
fb,
|
223
|
+
):
|
224
|
+
n_gals = logmp.shape[0]
|
225
|
+
|
226
|
+
n_params = N_MAH_PARAMS + N_MS_PARAMS + N_Q_PARAMS
|
227
|
+
galpop_params = jnp.zeros(shape=(n_gals, n_params))
|
228
|
+
galpop_params = galpop_params.at[:, 0].set(logmp)
|
229
|
+
galpop_params = galpop_params.at[:, 1].set(logtc)
|
230
|
+
galpop_params = galpop_params.at[:, 2].set(early_index)
|
231
|
+
galpop_params = galpop_params.at[:, 3].set(late_index)
|
232
|
+
|
233
|
+
galpop_params = galpop_params.at[:, 4].set(lgmcrit)
|
234
|
+
galpop_params = galpop_params.at[:, 5].set(lgy_at_mcrit)
|
235
|
+
galpop_params = galpop_params.at[:, 6].set(indx_lo)
|
236
|
+
galpop_params = galpop_params.at[:, 7].set(indx_hi)
|
237
|
+
galpop_params = galpop_params.at[:, 8].set(tau_dep)
|
238
|
+
|
239
|
+
galpop_params = galpop_params.at[:, 9].set(lg_qt)
|
240
|
+
galpop_params = galpop_params.at[:, 10].set(qlglgdt)
|
241
|
+
galpop_params = galpop_params.at[:, 11].set(lg_drop)
|
242
|
+
galpop_params = galpop_params.at[:, 12].set(lg_rejuv)
|
243
|
+
|
244
|
+
@jjit
|
245
|
+
def scan_func_galpop(carryover, el):
|
246
|
+
params = el
|
247
|
+
mah_params = params[:N_MAH_PARAMS]
|
248
|
+
i, j = N_MAH_PARAMS, N_MAH_PARAMS + N_MS_PARAMS
|
249
|
+
ms_params = params[i:j]
|
250
|
+
i = N_MAH_PARAMS + N_MS_PARAMS
|
251
|
+
q_params = params[i:]
|
252
|
+
sfh_galpop = kern(t, *mah_params, *ms_params, *q_params, lgt0, fb)
|
253
|
+
carryover = sfh_galpop
|
254
|
+
accumulated = sfh_galpop
|
255
|
+
return carryover, accumulated
|
256
|
+
|
257
|
+
scan_init = jnp.zeros_like(t)
|
258
|
+
scan_arr = galpop_params
|
259
|
+
res = lax.scan(scan_func_galpop, scan_init, scan_arr)
|
260
|
+
sfh_galpop = res[1]
|
261
|
+
return sfh_galpop
|
262
|
+
|
263
|
+
elif galpop_loop is None:
|
264
|
+
new_kern = kern
|
265
|
+
else:
|
266
|
+
msg = "Input `galpop_loop`={0} must be either `vmap` or `scan`"
|
267
|
+
raise ValueError(msg.format(galpop_loop))
|
268
|
+
return new_kern
|
@@ -1,10 +1,12 @@
|
|
1
1
|
"""
|
2
2
|
"""
|
3
|
+
|
3
4
|
from collections import OrderedDict, namedtuple
|
4
5
|
|
5
6
|
import numpy as np
|
6
7
|
from diffmah.defaults import MAH_K
|
7
8
|
from diffmah.individual_halo_assembly import (
|
9
|
+
_calc_halo_history,
|
8
10
|
_calc_halo_history_scalar,
|
9
11
|
_rolling_plaw_vs_logt,
|
10
12
|
)
|
@@ -14,7 +16,7 @@ from jax import numpy as jnp
|
|
14
16
|
from jax import vmap
|
15
17
|
|
16
18
|
from ..utils import _inverse_sigmoid, _jax_get_dt_array, _sigmoid
|
17
|
-
from .gas_consumption import _gas_conversion_kern
|
19
|
+
from .gas_consumption import _gas_conversion_kern, _vmap_gas_conversion_kern
|
18
20
|
|
19
21
|
DEFAULT_MS_PDICT = OrderedDict(
|
20
22
|
lgmcrit=12.0,
|
@@ -53,7 +55,7 @@ MS_BOUNDING_SIGMOID_PDICT = calculate_sigmoid_bounds(MS_PARAM_BOUNDS_PDICT)
|
|
53
55
|
|
54
56
|
|
55
57
|
@jjit
|
56
|
-
def _lax_ms_sfh_scalar_kern(t_form, mah_params, ms_params, lgt0, fb, t_table):
|
58
|
+
def _lax_ms_sfh_scalar_kern_scan(t_form, mah_params, ms_params, lgt0, fb, t_table):
|
57
59
|
logmp, logtc, early, late = mah_params
|
58
60
|
all_mah_params = lgt0, logmp, logtc, MAH_K, early, late
|
59
61
|
lgt_form = jnp.log10(t_form)
|
@@ -93,6 +95,38 @@ def _lax_ms_sfh_scalar_kern(t_form, mah_params, ms_params, lgt0, fb, t_table):
|
|
93
95
|
return sfr
|
94
96
|
|
95
97
|
|
98
|
+
@jjit
|
99
|
+
def _lax_ms_sfh_scalar_kern_sum(t_form, mah_params, ms_params, lgt0, fb, t_table):
|
100
|
+
logmp, logtc, early, late = mah_params
|
101
|
+
all_mah_params = lgt0, logmp, logtc, MAH_K, early, late
|
102
|
+
lgt_form = jnp.log10(t_form)
|
103
|
+
log_mah_at_tform = _rolling_plaw_vs_logt(lgt_form, *all_mah_params)
|
104
|
+
|
105
|
+
sfr_eff_params = ms_params[:4]
|
106
|
+
sfr_eff = _sfr_eff_plaw(log_mah_at_tform, *sfr_eff_params)
|
107
|
+
|
108
|
+
tau_dep = ms_params[4]
|
109
|
+
tau_dep_max = MS_BOUNDING_SIGMOID_PDICT["tau_dep"][3]
|
110
|
+
|
111
|
+
# compute inst. gas accretion
|
112
|
+
lgtacc = jnp.log10(t_table)
|
113
|
+
res = _calc_halo_history(lgtacc, *all_mah_params)
|
114
|
+
dmhdt_at_tacc, log_mah_at_tacc = res
|
115
|
+
dmgdt_inst = fb * dmhdt_at_tacc
|
116
|
+
|
117
|
+
# compute the consumption kernel
|
118
|
+
dt = t_table[1] - t_table[0]
|
119
|
+
kern = _vmap_gas_conversion_kern(t_form, t_table, dt, tau_dep, tau_dep_max)
|
120
|
+
|
121
|
+
# convolve
|
122
|
+
dmgas_dt = jnp.sum(dmgdt_inst * kern * dt)
|
123
|
+
sfr = dmgas_dt * sfr_eff
|
124
|
+
return sfr
|
125
|
+
|
126
|
+
|
127
|
+
_lax_ms_sfh_scalar_kern = _lax_ms_sfh_scalar_kern_sum
|
128
|
+
|
129
|
+
|
96
130
|
@jjit
|
97
131
|
def _sfr_eff_plaw(lgm, lgmcrit, lgy_at_mcrit, indx_lo, indx_hi):
|
98
132
|
"""Instantaneous baryon conversion efficiency of main sequence galaxies
|
@@ -78,6 +78,9 @@ def _quenching_kern_u_params(lgt, u_lg_qt, u_qlglgdt, u_lg_drop, u_lg_rejuv):
|
|
78
78
|
def _quenching_kern(lgt, lg_qt, lg_q_dt, q_drop, q_rejuv):
|
79
79
|
"""Base-10 logarithmic drop and symmetric rise in SFR over a time interval.
|
80
80
|
|
81
|
+
Note the relationship between the input lg_q_dt and q_params[1]=qlglgdt:
|
82
|
+
lg_q_dt = 10**qlglgdt
|
83
|
+
|
81
84
|
Parameters
|
82
85
|
----------
|
83
86
|
lgt : ndarray
|