diffstar 0.2.4__tar.gz → 0.3.1__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (104) hide show
  1. {diffstar-0.2.4 → diffstar-0.3.1}/.github/workflows/linting.yml +1 -1
  2. diffstar-0.3.1/.github/workflows/monthly-warning-test.yml +55 -0
  3. {diffstar-0.2.4 → diffstar-0.3.1}/.github/workflows/test_releases.yml +1 -1
  4. {diffstar-0.2.4 → diffstar-0.3.1}/.github/workflows/tests_cron.yml +3 -2
  5. {diffstar-0.2.4 → diffstar-0.3.1}/CHANGES.rst +10 -0
  6. {diffstar-0.2.4/diffstar.egg-info → diffstar-0.3.1}/PKG-INFO +2 -2
  7. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/__init__.py +5 -2
  8. diffstar-0.3.1/diffstar/_version.py +1 -0
  9. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/defaults.py +44 -0
  10. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/kernels/gas_consumption.py +6 -0
  11. diffstar-0.3.1/diffstar/kernels/history_kernel_builders.py +268 -0
  12. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/kernels/main_sequence_kernels.py +36 -2
  13. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/kernels/quenching_kernels.py +3 -0
  14. diffstar-0.3.1/diffstar/kernels/tests/test_kernel_builders.py +161 -0
  15. diffstar-0.3.1/diffstar/sfh_model.py +140 -0
  16. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/test_defaults.py +26 -0
  17. diffstar-0.3.1/diffstar/tests/test_main_sequence_kernels.py +28 -0
  18. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/test_sfh.py +1 -1
  19. diffstar-0.3.1/diffstar/tests/test_sfh_model.py +156 -0
  20. diffstar-0.3.1/diffstar/tests/test_utils.py +72 -0
  21. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/utils.py +80 -2
  22. {diffstar-0.2.4 → diffstar-0.3.1/diffstar.egg-info}/PKG-INFO +2 -2
  23. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar.egg-info/SOURCES.txt +5 -0
  24. {diffstar-0.2.4 → diffstar-0.3.1}/docs/source/demo_diffstar_fitter.ipynb +8 -0
  25. {diffstar-0.2.4 → diffstar-0.3.1}/docs/source/demo_diffstar_sfh.ipynb +21 -25
  26. {diffstar-0.2.4 → diffstar-0.3.1}/pyproject.toml +1 -1
  27. diffstar-0.2.4/diffstar/_version.py +0 -1
  28. diffstar-0.2.4/diffstar/kernels/tests/test_kernel_builders.py +0 -8
  29. diffstar-0.2.4/diffstar/tests/test_utils.py +0 -24
  30. {diffstar-0.2.4 → diffstar-0.3.1}/.coveragerc +0 -0
  31. {diffstar-0.2.4 → diffstar-0.3.1}/.git_archival.txt +0 -0
  32. {diffstar-0.2.4 → diffstar-0.3.1}/.gitattributes +0 -0
  33. {diffstar-0.2.4 → diffstar-0.3.1}/.gitignore +0 -0
  34. {diffstar-0.2.4 → diffstar-0.3.1}/.readthedocs.yml +0 -0
  35. {diffstar-0.2.4 → diffstar-0.3.1}/LICENSE.rst +0 -0
  36. {diffstar-0.2.4 → diffstar-0.3.1}/README.md +0 -0
  37. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/data_loaders/__init__.py +0 -0
  38. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/data_loaders/load_bpl.py +0 -0
  39. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/data_loaders/load_smah_data.py +0 -0
  40. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/data_loaders/tests/__init__.py +0 -0
  41. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/data_loaders/tests/test_load_bpl.py +0 -0
  42. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/fitting_helpers/__init__.py +0 -0
  43. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/fitting_helpers/fit_smah_helpers.py +0 -0
  44. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/fitting_helpers/fitting_kernels.py +0 -0
  45. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/fitting_helpers/param_clippers.py +0 -0
  46. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/fitting_helpers/stars.py +0 -0
  47. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/fitting_helpers/tests/__init__.py +0 -0
  48. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/fitting_helpers/tests/test_fitting_kernels.py +0 -0
  49. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/fitting_helpers/tests/test_fitting_kernels_are_frozen.py +0 -0
  50. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/fitting_helpers/tests/test_fitting_smah_helpers.py +0 -0
  51. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/fitting_helpers/tests/test_param_clippers.py +0 -0
  52. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/fitting_helpers/tests/test_stars.py +0 -0
  53. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/fitting_helpers/utils.py +0 -0
  54. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/kernels/__init__.py +0 -0
  55. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/kernels/kernel_builders.py +0 -0
  56. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/kernels/tests/__init__.py +0 -0
  57. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/kernels/tests/test_frozen_diffstar_kernels.py +0 -0
  58. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/kernels/tests/test_quenching_kernels.py +0 -0
  59. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/kernels/tests/testing_data/mah_params_testing_v0.1.0.txt +0 -0
  60. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/kernels/tests/testing_data/ms_params_testing_v0.1.0.txt +0 -0
  61. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/kernels/tests/testing_data/q_params_testing_v0.1.0.txt +0 -0
  62. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/kernels/tests/testing_data/sfh_table_testing_v0.1.0.txt +0 -0
  63. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/kernels/tests/testing_data/t_table_testing_v0.1.0.txt +0 -0
  64. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/kernels/tests/testing_data/u_ms_params_testing_v0.1.0.txt +0 -0
  65. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/kernels/tests/testing_data/u_q_params_testing_v0.1.0.txt +0 -0
  66. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/sfh.py +0 -0
  67. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/__init__.py +0 -0
  68. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/test_gas.py +0 -0
  69. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/test_lax_main_sequence.py +0 -0
  70. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/test_lax_sfh.py +0 -0
  71. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/test_quenching.py +0 -0
  72. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/testing_data/default_params_test_diffmah_params.txt +0 -0
  73. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/testing_data/default_params_test_dmhdt.txt +0 -0
  74. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/testing_data/default_params_test_dt.txt +0 -0
  75. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/testing_data/default_params_test_lgt.txt +0 -0
  76. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/testing_data/default_params_test_log_mah.txt +0 -0
  77. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/testing_data/default_params_test_sfh.txt +0 -0
  78. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/testing_data/default_params_test_u_ms_params.txt +0 -0
  79. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/testing_data/default_params_test_u_q_params.txt +0 -0
  80. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/testing_data/dt_bpl.txt +0 -0
  81. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/testing_data/halo_ids_test_sample.txt +0 -0
  82. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/testing_data/lgt_bpl.txt +0 -0
  83. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/testing_data/mah_params_test_sample.txt +0 -0
  84. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/testing_data/ms_u_params_test_sample.txt +0 -0
  85. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/testing_data/q_u_params_test_sample.txt +0 -0
  86. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar/tests/testing_data/sfh_test_sample.txt +0 -0
  87. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar.egg-info/dependency_links.txt +0 -0
  88. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar.egg-info/requires.txt +0 -0
  89. {diffstar-0.2.4 → diffstar-0.3.1}/diffstar.egg-info/top_level.txt +0 -0
  90. {diffstar-0.2.4 → diffstar-0.3.1}/docs/Makefile +0 -0
  91. {diffstar-0.2.4 → diffstar-0.3.1}/docs/make.bat +0 -0
  92. {diffstar-0.2.4 → diffstar-0.3.1}/docs/source/_static/README.txt +0 -0
  93. {diffstar-0.2.4 → diffstar-0.3.1}/docs/source/citation.rst +0 -0
  94. {diffstar-0.2.4 → diffstar-0.3.1}/docs/source/conf.py +0 -0
  95. {diffstar-0.2.4 → diffstar-0.3.1}/docs/source/demos.rst +0 -0
  96. {diffstar-0.2.4 → diffstar-0.3.1}/docs/source/index.rst +0 -0
  97. {diffstar-0.2.4 → diffstar-0.3.1}/docs/source/installation.rst +0 -0
  98. {diffstar-0.2.4 → diffstar-0.3.1}/docs/source/reference.rst +0 -0
  99. {diffstar-0.2.4 → diffstar-0.3.1}/docs/source/rtd_environment.yaml +0 -0
  100. {diffstar-0.2.4 → diffstar-0.3.1}/requirements.txt +0 -0
  101. {diffstar-0.2.4 → diffstar-0.3.1}/scripts/generate_unit_testing_data.py +0 -0
  102. {diffstar-0.2.4 → diffstar-0.3.1}/scripts/history_fitting_script.py +0 -0
  103. {diffstar-0.2.4 → diffstar-0.3.1}/setup.cfg +0 -0
  104. {diffstar-0.2.4 → diffstar-0.3.1}/setup.py +0 -0
@@ -16,7 +16,7 @@ jobs:
16
16
 
17
17
  - uses: conda-incubator/setup-miniconda@v2
18
18
  with:
19
- python-version: 3.9
19
+ python-version: 3.11
20
20
  channels: conda-forge,defaults
21
21
  channel-priority: strict
22
22
  show-channel-urls: true
@@ -0,0 +1,55 @@
1
name: Test for Warnings

on:
  workflow_dispatch: null
  schedule:
    # GitHub cron schedules are in UTC: this runs at 08:15 UTC on the first of
    # every month, i.e. 3:15am Central Daylight Time / 2:15am Central Standard Time
    - cron: '15 8 1 * *'

jobs:
  tests:
    name: tests
    runs-on: "ubuntu-latest"

    steps:
      # checkout@v4: v2 runs on the deprecated Node 12 runtime and triggers
      # runner deprecation warnings
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - uses: conda-incubator/setup-miniconda@v2
        with:
          python-version: 3.11
          channels: conda-forge,defaults
          channel-priority: strict
          show-channel-urls: true
          miniforge-version: latest
          miniforge-variant: Mambaforge
          use-mamba: true

      - name: configure conda and install code
        # Test against current main branch of diffmah and dsps
        shell: bash -l {0}
        run: |
          conda config --set always_yes yes
          mamba install --quiet \
            --file=requirements.txt
          python -m pip install --no-deps -e .
          mamba install -y -q \
            flake8 \
            pytest \
            pytest-xdist \
            pytest-cov \
            pip \
            setuptools \
            "setuptools_scm>=7,<8" \
            python-build
          pip uninstall diffmah --yes
          pip install --no-deps git+https://github.com/ArgonneCPAC/diffmah.git
          pip install --no-deps git+https://github.com/ArgonneCPAC/dsps.git
          python -m pip install --no-build-isolation --no-deps -e .

      - name: test that no warnings are raised
        shell: bash -l {0}
        run: |
          # Promote every Python warning to an error for the test run
          export PYTHONWARNINGS=error
          pytest -v diffstar --cov --cov-report=xml
@@ -22,7 +22,7 @@ jobs:
22
22
 
23
23
  - uses: conda-incubator/setup-miniconda@v2
24
24
  with:
25
- python-version: 3.9
25
+ python-version: 3.11
26
26
  channels: conda-forge,defaults
27
27
  channel-priority: strict
28
28
  show-channel-urls: true
@@ -22,7 +22,7 @@ jobs:
22
22
 
23
23
  - uses: conda-incubator/setup-miniconda@v2
24
24
  with:
25
- python-version: 3.9
25
+ python-version: 3.11
26
26
  channels: conda-forge,defaults
27
27
  channel-priority: strict
28
28
  show-channel-urls: true
@@ -31,7 +31,7 @@ jobs:
31
31
  use-mamba: true
32
32
 
33
33
  - name: configure conda and install code
34
- # Test against current main branch of diffmah
34
+ # Test against current main branch of diffmah and dsps
35
35
  shell: bash -l {0}
36
36
  run: |
37
37
  conda config --set always_yes yes
@@ -49,6 +49,7 @@ jobs:
49
49
  python-build
50
50
  pip uninstall diffmah --yes
51
51
  pip install --no-deps git+https://github.com/ArgonneCPAC/diffmah.git
52
+ pip install --no-deps git+https://github.com/ArgonneCPAC/dsps.git
52
53
  python -m pip install --no-build-isolation --no-deps -e .
53
54
 
54
55
  - name: test
@@ -1,3 +1,13 @@
1
+ 0.3.1 (2024-06-19)
2
+ ------------------
3
+ - Performance improvements for calc_sfh_galpop and calc_sfh_singlegal
4
+
5
+
6
+ 0.3.0 (2024-01-17)
7
+ ------------------
8
+ - Implement new API for primary user-facing functions calc_sfh_galpop and calc_sfh_singlegal
9
+
10
+
1
11
  0.2.4 (2024-01-16)
2
12
  ------------------
3
13
  - Require diffmah>=0.5.0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: diffstar
3
- Version: 0.2.4
3
+ Version: 0.3.1
4
4
  Summary: Differentiable Star Formation Histories
5
5
  Author-email: Alex Alarcon <alexalarcongonzalez@gmail.com>
6
6
  License: BSD 3-Clause License
@@ -35,7 +35,7 @@ License: BSD 3-Clause License
35
35
 
36
36
  Project-URL: home, https://github.com/ArgonneCPAC/diffstar
37
37
  Classifier: Programming Language :: Python :: 3
38
- Requires-Python: >=3.8
38
+ Requires-Python: >=3.9
39
39
  Description-Content-Type: text/markdown
40
40
  License-File: LICENSE.rst
41
41
  Requires-Dist: diffmah>=0.5.0
@@ -7,8 +7,11 @@ from .defaults import (
7
7
  DEFAULT_DIFFSTAR_U_PARAMS,
8
8
  DiffstarParams,
9
9
  DiffstarUParams,
10
+ MSParams,
11
+ MSUParams,
12
+ QParams,
13
+ QUParams,
10
14
  get_bounded_diffstar_params,
11
15
  get_unbounded_diffstar_params,
12
16
  )
13
- from .kernels import get_ms_sfh_from_mah_kern, get_sfh_from_mah_kern
14
- from .sfh import sfh_galpop, sfh_singlegal
17
+ from .sfh_model import calc_sfh_galpop, calc_sfh_singlegal
@@ -0,0 +1 @@
1
# Package version string for this release
__version__ = "0.3.1"
@@ -14,6 +14,7 @@ LGT0 = np.log10(TODAY)
14
14
# Constants related to SFH integrals
SFR_MIN = 1.0e-14  # floor applied to star-formation rates by the SFH kernels
T_BIRTH_MIN = 1.0e-3  # default earliest time of the t_acc integration
T_TABLE_MIN = 1.0e-2
N_T_LGSM_INTEGRATION = 100
DEFAULT_N_STEPS = 50  # default number of steps in the t_acc integration
19
20
 
@@ -53,6 +54,28 @@ DEFAULT_DIFFSTAR_U_PARAMS = DiffstarUParams(DEFAULT_U_MS_PARAMS, DEFAULT_U_Q_PAR
53
54
 
54
55
@jjit
def get_bounded_diffstar_params(diffstar_u_params):
    """Map unbounded diffstar parameters onto their bounded counterparts.

    The result is the parameter format consumed by diffstar.calc_sfh_singlegal
    and diffstar.calc_sfh_galpop.

    Parameters
    ----------
    diffstar_u_params : namedtuple, length 2
        DiffstarUParams = u_ms_params, u_q_params
        Both entries are tuples of floats or ndarrays:
            u_ms_params = u_lgmcrit, u_lgy_at_mcrit, u_indx_lo, u_indx_hi, u_tau_dep
            u_q_params = u_lg_qt, u_qlglgdt, u_lg_drop, u_lg_rejuv

    Returns
    -------
    diffstar_params : namedtuple, length 2
        DiffstarParams = ms_params, q_params
        Both entries are tuples of floats or ndarrays:
            ms_params = lgmcrit, lgy_at_mcrit, indx_lo, indx_hi, tau_dep
            q_params = lg_qt, qlglgdt, lg_drop, lg_rejuv

    """
    u_ms, u_q = diffstar_u_params.u_ms_params, diffstar_u_params.u_q_params
    bounded_ms = _get_bounded_sfr_params(*u_ms)
    bounded_q = _get_bounded_q_params(*u_q)
    return DiffstarParams(MSParams(*bounded_ms), QParams(*bounded_q))
@@ -60,6 +83,27 @@ def get_bounded_diffstar_params(diffstar_u_params):
60
83
 
61
84
@jjit
def get_unbounded_diffstar_params(diffstar_params):
    """Map bounded diffstar parameters onto their unbounded counterparts.

    This is the inverse of get_bounded_diffstar_params.

    Parameters
    ----------
    diffstar_params : namedtuple, length 2
        DiffstarParams = ms_params, q_params
        Both entries are tuples of floats or ndarrays:
            ms_params = lgmcrit, lgy_at_mcrit, indx_lo, indx_hi, tau_dep
            q_params = lg_qt, qlglgdt, lg_drop, lg_rejuv

    Returns
    -------
    diffstar_u_params : namedtuple, length 2
        DiffstarUParams = u_ms_params, u_q_params
        Both entries are tuples of floats or ndarrays:
            u_ms_params = u_lgmcrit, u_lgy_at_mcrit, u_indx_lo, u_indx_hi, u_tau_dep
            u_q_params = u_lg_qt, u_qlglgdt, u_lg_drop, u_lg_rejuv

    """
    ms, q = diffstar_params.ms_params, diffstar_params.q_params
    unbounded_ms = _get_unbounded_sfr_params(*ms)
    unbounded_q = _get_unbounded_q_params(*q)
    return DiffstarUParams(MSUParams(*unbounded_ms), QUParams(*unbounded_q))
@@ -1,5 +1,6 @@
1
1
  """Calculate mass of gas available for star formation from the freshly accreted gas.
2
2
  """
3
+
3
4
  from jax import jit as jjit
4
5
  from jax import lax
5
6
  from jax import numpy as jnp
@@ -28,6 +29,11 @@ def _gas_conversion_kern(t_form, t_acc, dt, tau_dep, tau_dep_max):
28
29
  return tri_kern
29
30
 
30
31
 
32
# Vectorized forms of _gas_conversion_kern(t_form, t_acc, dt, tau_dep, tau_dep_max).
# _a maps over (t_form, dt); _b maps over t_acc only.
_a = (0, None, 0, None, None)
_b = (None, 0, None, None, None)

# Single vmap over t_acc: one t_form evaluated against an array of accretion times
_vmap_gas_conversion_kern = jjit(vmap(_gas_conversion_kern, in_axes=_b))

# Nested vmap: outer loop over paired (t_form, dt) entries, inner loop over t_acc
_depletion_kernel = jjit(vmap(vmap(_gas_conversion_kern, in_axes=_b), in_axes=_a))
33
39
 
@@ -0,0 +1,268 @@
1
+ """
2
+ """
3
+
4
+ from jax import jit as jjit
5
+ from jax import lax
6
+ from jax import numpy as jnp
7
+ from jax import vmap
8
+
9
+ from ..defaults import (
10
+ DEFAULT_DIFFSTAR_PARAMS,
11
+ DEFAULT_MAH_PARAMS,
12
+ DEFAULT_N_STEPS,
13
+ SFR_MIN,
14
+ T_BIRTH_MIN,
15
+ )
16
+ from .main_sequence_kernels import (
17
+ _lax_ms_sfh_scalar_kern_scan,
18
+ _lax_ms_sfh_scalar_kern_sum,
19
+ )
20
+ from .quenching_kernels import _quenching_kern
21
+
22
__all__ = ("build_sfh_from_mah_kernel",)

# Number of scalar parameters in each group. The flat argument lists of the
# kernels built in this module are ordered: mah params, ms params, q params.
N_MAH_PARAMS, N_MS_PARAMS, N_Q_PARAMS = (
    len(DEFAULT_MAH_PARAMS),
    len(DEFAULT_DIFFSTAR_PARAMS.ms_params),
    len(DEFAULT_DIFFSTAR_PARAMS.q_params),
)
27
+
28
+
29
def build_sfh_from_mah_kernel(
    n_steps=DEFAULT_N_STEPS,
    tacc_integration_min=T_BIRTH_MIN,
    tobs_loop=None,
    galpop_loop=None,
    tform_loop="sum",
):
    """Build a JAX-jitted kernel to calculate SFHs of a galaxy population.

    Parameters
    ----------
    n_steps : int, optional
        Number of timesteps to use in the tacc integration.
        Default is DEFAULT_N_STEPS.

    tacc_integration_min : float, optional
        Earliest time to use in the tacc integrations.
        Default is T_BIRTH_MIN (0.001 Gyr).

    tobs_loop : string, optional
        Argument specifies whether the input time of observation is a scalar or array.
        Default argument is None, for a JAX kernel that assumes scalar input for tobs.
        For a JAX kernel that assumes an array input for tobs,
        options are either 'vmap' or 'scan', specifying the calculation method.

    galpop_loop : string, optional
        Argument specifies whether the input galaxy/halo parameters assumed by the
        returned JAX kernel pertain to a single galaxy or a population.
        Default argument is None, for a single-galaxy JAX kernel.
        For a JAX kernel that assumes galaxy population,
        options are either 'vmap' or 'scan', specifying the calculation method.

    tform_loop : string, optional
        Use 'sum' for faster vmap-based calculation and 'scan' for slower alternative.
        Default is 'sum'.

    Returns
    -------
    sfh_from_mah_kern : function
        JAX-jitted function that calculates SFH in accord with the input arguments.
        All parameters are passed as flat positional arguments:

        def sfh_from_mah_kern(
            t,
            logmp, logtc, early_index, late_index,
            lgmcrit, lgy_at_mcrit, indx_lo, indx_hi, tau_dep,
            lg_qt, qlglgdt, lg_drop, lg_rejuv,
            lgt0, fb,
        ):
            return sfh

        t is an array when tobs_loop is 'vmap'/'scan' and a scalar otherwise.
        The 13 galaxy parameters are arrays over the population when galpop_loop
        is 'vmap'/'scan' and scalars otherwise.

    Raises
    ------
    ValueError
        If tform_loop is not 'sum' or 'scan', or if tobs_loop / galpop_loop
        is not one of None, 'vmap' or 'scan'.

    """
    if tform_loop == "sum":
        _lax_ms_sfh_scalar_kern = _lax_ms_sfh_scalar_kern_sum
    elif tform_loop == "scan":
        _lax_ms_sfh_scalar_kern = _lax_ms_sfh_scalar_kern_scan
    else:
        # Fail at build time rather than with a NameError on first call
        msg = "Input `tform_loop`={0} must be either `sum` or `scan`"
        raise ValueError(msg.format(tform_loop))

    # Fractional grid on [0, 1]; rescaled onto [t_min, t_form] at every call
    uniform_table = jnp.linspace(0, 1, n_steps)

    @jjit
    def _kern(
        t_form,
        logmp,
        logtc,
        early_index,
        late_index,
        lgmcrit,
        lgy_at_mcrit,
        indx_lo,
        indx_hi,
        tau_dep,
        lg_qt,
        qlglgdt,
        lg_drop,
        lg_rejuv,
        lgt0,
        fb,
    ):
        mah_params = logmp, logtc, early_index, late_index
        ms_params = lgmcrit, lgy_at_mcrit, indx_lo, indx_hi, tau_dep

        # Integrate accretion over [t_min, t_form], with t_min no earlier than
        # tacc_integration_min and no earlier than t_form - tau_dep
        t_min = jnp.max(jnp.array((tacc_integration_min, t_form - tau_dep)))
        t_table = t_min + uniform_table * (t_form - t_min)
        args = t_form, mah_params, ms_params, lgt0, fb, t_table
        ms_sfr = _lax_ms_sfh_scalar_kern(*args)
        lgt_form = jnp.log10(t_form)

        # _quenching_kern takes lg_q_dt = 10**qlglgdt rather than qlglgdt itself
        lg_q_dt = 10**qlglgdt
        qkern_inputs = lg_qt, lg_q_dt, lg_drop, lg_rejuv
        qfunc = _quenching_kern(lgt_form, *qkern_inputs)
        sfr = qfunc * ms_sfr
        # Floor the SFR at SFR_MIN
        sfr = lax.cond(sfr < SFR_MIN, lambda x: SFR_MIN, lambda x: x, sfr)
        return sfr

    kern_with_tobs_loop = _get_kern_with_tobs_loop(_kern, tobs_loop)
    sfh_from_mah_kern = _get_kern_with_galpop_loop(kern_with_tobs_loop, galpop_loop)

    return sfh_from_mah_kern
119
+
120
+
121
+ def _get_kern_with_tobs_loop(kern, tobs_loop):
122
+ if tobs_loop == "vmap":
123
+ _t = [
124
+ 0,
125
+ *[None] * N_MAH_PARAMS,
126
+ *[None] * N_MS_PARAMS,
127
+ *[None] * N_Q_PARAMS,
128
+ None,
129
+ None,
130
+ ]
131
+ new_kern = jjit(vmap(kern, in_axes=_t))
132
+ elif tobs_loop == "scan":
133
+
134
+ @jjit
135
+ def new_kern(
136
+ tarr,
137
+ logmp,
138
+ logtc,
139
+ early_index,
140
+ late_index,
141
+ lgmcrit,
142
+ lgy_at_mcrit,
143
+ indx_lo,
144
+ indx_hi,
145
+ tau_dep,
146
+ lg_qt,
147
+ qlglgdt,
148
+ lg_drop,
149
+ lg_rejuv,
150
+ lgt0,
151
+ fb,
152
+ ):
153
+ @jjit
154
+ def scan_func_time_array(carryover, el):
155
+ t_form = el
156
+ sfr_at_t_form = kern(
157
+ t_form,
158
+ logmp,
159
+ logtc,
160
+ early_index,
161
+ late_index,
162
+ lgmcrit,
163
+ lgy_at_mcrit,
164
+ indx_lo,
165
+ indx_hi,
166
+ tau_dep,
167
+ lg_qt,
168
+ qlglgdt,
169
+ lg_drop,
170
+ lg_rejuv,
171
+ lgt0,
172
+ fb,
173
+ )
174
+ carryover = sfr_at_t_form
175
+ accumulated = sfr_at_t_form
176
+ return carryover, accumulated
177
+
178
+ scan_init = 0.0
179
+ scan_arr = tarr
180
+ res = lax.scan(scan_func_time_array, scan_init, scan_arr)
181
+ sfh = res[1]
182
+ return sfh
183
+
184
+ elif tobs_loop is None:
185
+ new_kern = kern
186
+ else:
187
+ msg = "Input `tobs_loop`={0} must be either `vmap` or `scan`"
188
+ raise ValueError(msg.format(tobs_loop))
189
+ return new_kern
190
+
191
+
192
+ def _get_kern_with_galpop_loop(kern, galpop_loop):
193
+ if galpop_loop == "vmap":
194
+ _g = [
195
+ None,
196
+ *[0] * N_MAH_PARAMS,
197
+ *[0] * N_MS_PARAMS,
198
+ *[0] * N_Q_PARAMS,
199
+ None,
200
+ None,
201
+ ]
202
+ new_kern = jjit(vmap(kern, in_axes=_g))
203
+ elif galpop_loop == "scan":
204
+
205
+ @jjit
206
+ def new_kern(
207
+ t,
208
+ logmp,
209
+ logtc,
210
+ early_index,
211
+ late_index,
212
+ lgmcrit,
213
+ lgy_at_mcrit,
214
+ indx_lo,
215
+ indx_hi,
216
+ tau_dep,
217
+ lg_qt,
218
+ qlglgdt,
219
+ lg_drop,
220
+ lg_rejuv,
221
+ lgt0,
222
+ fb,
223
+ ):
224
+ n_gals = logmp.shape[0]
225
+
226
+ n_params = N_MAH_PARAMS + N_MS_PARAMS + N_Q_PARAMS
227
+ galpop_params = jnp.zeros(shape=(n_gals, n_params))
228
+ galpop_params = galpop_params.at[:, 0].set(logmp)
229
+ galpop_params = galpop_params.at[:, 1].set(logtc)
230
+ galpop_params = galpop_params.at[:, 2].set(early_index)
231
+ galpop_params = galpop_params.at[:, 3].set(late_index)
232
+
233
+ galpop_params = galpop_params.at[:, 4].set(lgmcrit)
234
+ galpop_params = galpop_params.at[:, 5].set(lgy_at_mcrit)
235
+ galpop_params = galpop_params.at[:, 6].set(indx_lo)
236
+ galpop_params = galpop_params.at[:, 7].set(indx_hi)
237
+ galpop_params = galpop_params.at[:, 8].set(tau_dep)
238
+
239
+ galpop_params = galpop_params.at[:, 9].set(lg_qt)
240
+ galpop_params = galpop_params.at[:, 10].set(qlglgdt)
241
+ galpop_params = galpop_params.at[:, 11].set(lg_drop)
242
+ galpop_params = galpop_params.at[:, 12].set(lg_rejuv)
243
+
244
+ @jjit
245
+ def scan_func_galpop(carryover, el):
246
+ params = el
247
+ mah_params = params[:N_MAH_PARAMS]
248
+ i, j = N_MAH_PARAMS, N_MAH_PARAMS + N_MS_PARAMS
249
+ ms_params = params[i:j]
250
+ i = N_MAH_PARAMS + N_MS_PARAMS
251
+ q_params = params[i:]
252
+ sfh_galpop = kern(t, *mah_params, *ms_params, *q_params, lgt0, fb)
253
+ carryover = sfh_galpop
254
+ accumulated = sfh_galpop
255
+ return carryover, accumulated
256
+
257
+ scan_init = jnp.zeros_like(t)
258
+ scan_arr = galpop_params
259
+ res = lax.scan(scan_func_galpop, scan_init, scan_arr)
260
+ sfh_galpop = res[1]
261
+ return sfh_galpop
262
+
263
+ elif galpop_loop is None:
264
+ new_kern = kern
265
+ else:
266
+ msg = "Input `galpop_loop`={0} must be either `vmap` or `scan`"
267
+ raise ValueError(msg.format(galpop_loop))
268
+ return new_kern
@@ -1,10 +1,12 @@
1
1
  """
2
2
  """
3
+
3
4
  from collections import OrderedDict, namedtuple
4
5
 
5
6
  import numpy as np
6
7
  from diffmah.defaults import MAH_K
7
8
  from diffmah.individual_halo_assembly import (
9
+ _calc_halo_history,
8
10
  _calc_halo_history_scalar,
9
11
  _rolling_plaw_vs_logt,
10
12
  )
@@ -14,7 +16,7 @@ from jax import numpy as jnp
14
16
  from jax import vmap
15
17
 
16
18
  from ..utils import _inverse_sigmoid, _jax_get_dt_array, _sigmoid
17
- from .gas_consumption import _gas_conversion_kern
19
+ from .gas_consumption import _gas_conversion_kern, _vmap_gas_conversion_kern
18
20
 
19
21
  DEFAULT_MS_PDICT = OrderedDict(
20
22
  lgmcrit=12.0,
@@ -53,7 +55,7 @@ MS_BOUNDING_SIGMOID_PDICT = calculate_sigmoid_bounds(MS_PARAM_BOUNDS_PDICT)
53
55
 
54
56
 
55
57
  @jjit
56
- def _lax_ms_sfh_scalar_kern(t_form, mah_params, ms_params, lgt0, fb, t_table):
58
+ def _lax_ms_sfh_scalar_kern_scan(t_form, mah_params, ms_params, lgt0, fb, t_table):
57
59
  logmp, logtc, early, late = mah_params
58
60
  all_mah_params = lgt0, logmp, logtc, MAH_K, early, late
59
61
  lgt_form = jnp.log10(t_form)
@@ -93,6 +95,38 @@ def _lax_ms_sfh_scalar_kern(t_form, mah_params, ms_params, lgt0, fb, t_table):
93
95
  return sfr
94
96
 
95
97
 
98
@jjit
def _lax_ms_sfh_scalar_kern_sum(t_form, mah_params, ms_params, lgt0, fb, t_table):
    """Main-sequence SFR at t_form, convolving gas accretion with a vectorized sum.

    This is the vmap/sum counterpart of _lax_ms_sfh_scalar_kern_scan. The gas
    consumption kernel is evaluated on all of t_table at once, and the
    convolution is a single weighted sum.

    Parameters
    ----------
    t_form : float
        Time at which the SFR is evaluated.
    mah_params : sequence of 4 floats
        logmp, logtc, early, late
    ms_params : sequence of 5 floats
        lgmcrit, lgy_at_mcrit, indx_lo, indx_hi, tau_dep
    lgt0 : float
        Base-10 log of the present-day time (cf. LGT0 in diffstar.defaults).
    fb : float
        Multiplier converting halo mass accretion rate into gas accretion rate
        (presumably the baryon fraction).
    t_table : ndarray of shape (n,)
        Accretion times to integrate over. Must be uniformly spaced with
        n >= 2, because the integration step is t_table[1] - t_table[0].

    Returns
    -------
    sfr : float
        Main-sequence star formation rate at t_form.

    """
    logmp, logtc, early, late = mah_params
    all_mah_params = lgt0, logmp, logtc, MAH_K, early, late
    lgt_form = jnp.log10(t_form)
    log_mah_at_tform = _rolling_plaw_vs_logt(lgt_form, *all_mah_params)

    # Efficiency depends on the first four ms_params; tau_dep is the fifth
    sfr_eff = _sfr_eff_plaw(log_mah_at_tform, *ms_params[:4])

    tau_dep = ms_params[4]
    tau_dep_max = MS_BOUNDING_SIGMOID_PDICT["tau_dep"][3]

    # Instantaneous gas accretion rate at every t_acc in t_table
    lgtacc = jnp.log10(t_table)
    dmhdt_at_tacc, _ = _calc_halo_history(lgtacc, *all_mah_params)
    dmgdt_inst = fb * dmhdt_at_tacc

    # Consumption kernel, evaluated for every t_acc at once
    dt = t_table[1] - t_table[0]
    kern = _vmap_gas_conversion_kern(t_form, t_table, dt, tau_dep, tau_dep_max)

    # Riemann-sum convolution of gas accretion with the consumption kernel
    dmgas_dt = jnp.sum(dmgdt_inst * kern * dt)
    sfr = dmgas_dt * sfr_eff
    return sfr


# Backward-compatible alias: the default main-sequence kernel is the sum-based one
_lax_ms_sfh_scalar_kern = _lax_ms_sfh_scalar_kern_sum
128
+
129
+
96
130
  @jjit
97
131
  def _sfr_eff_plaw(lgm, lgmcrit, lgy_at_mcrit, indx_lo, indx_hi):
98
132
  """Instantaneous baryon conversion efficiency of main sequence galaxies
@@ -78,6 +78,9 @@ def _quenching_kern_u_params(lgt, u_lg_qt, u_qlglgdt, u_lg_drop, u_lg_rejuv):
78
78
  def _quenching_kern(lgt, lg_qt, lg_q_dt, q_drop, q_rejuv):
79
79
  """Base-10 logarithmic drop and symmetric rise in SFR over a time interval.
80
80
 
81
+ Note the relationship between the input lg_q_dt and q_params[1]=qlglgdt:
82
+ lg_q_dt = 10**qlglgdt
83
+
81
84
  Parameters
82
85
  ----------
83
86
  lgt : ndarray